00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 #ifndef _PASSENGER_MESSAGE_CLIENT_H_
00026 #define _PASSENGER_MESSAGE_CLIENT_H_
00027
00028 #include <boost/shared_ptr.hpp>
00029 #include <string>
00030
00031 #include "StaticString.h"
00032 #include "MessageChannel.h"
00033 #include "Exceptions.h"
00034 #include "Utils/IOUtils.h"
00035
00036
00037 namespace Passenger {
00038
00039 using namespace std;
00040 using namespace boost;
00041
00042 class MessageClient {
00043 protected:
00044 FileDescriptor fd;
00045 MessageChannel channel;
00046 bool shouldAutoDisconnect;
00047
00048
00049
00050 virtual void sendUsername(MessageChannel &channel, const string &username) {
00051 channel.writeScalar(username);
00052 }
00053
00054 virtual void sendPassword(MessageChannel &channel, const StaticString &userSuppliedPassword) {
00055 channel.writeScalar(userSuppliedPassword.c_str(), userSuppliedPassword.size());
00056 }
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067 void authenticate(const string &username, const StaticString &userSuppliedPassword) {
00068 vector<string> args;
00069
00070 sendUsername(channel, username);
00071 sendPassword(channel, userSuppliedPassword);
00072
00073 if (!channel.read(args)) {
00074 throw IOException("The message server did not send an authentication response.");
00075 } else if (args.size() != 1) {
00076 throw IOException("The authentication response that the message server sent is not valid.");
00077 } else if (args[0] != "ok") {
00078 throw SecurityException("The message server denied authentication: " + args[0]);
00079 }
00080 }
00081
00082 void autoDisconnect() {
00083 if (shouldAutoDisconnect) {
00084
00085 fd = FileDescriptor();
00086 channel = MessageChannel();
00087 }
00088 }
00089
00090 public:
00091
00092
00093
00094
00095 MessageClient() {
00096
00097
00098
00099
00100 shouldAutoDisconnect = true;
00101 }
00102
00103 virtual ~MessageClient() { }
00104
00105
00106
00107
00108
00109
00110
00111
00112
00113
00114
00115
00116
00117
00118
00119
00120
00121
00122
00123
00124
00125 MessageClient *connect(const string &serverAddress, const string &username,
00126 const StaticString &userSuppliedPassword)
00127 {
00128 TRACE_POINT();
00129 try {
00130 fd = connectToServer(serverAddress.c_str());
00131 channel = MessageChannel(fd);
00132
00133 vector<string> args;
00134 if (!read(args)) {
00135 throw IOException("The message server closed the connection before sending a version identifier.");
00136 }
00137 if (args.size() != 2 || args[0] != "version") {
00138 throw IOException("The message server didn't sent a valid version identifier.");
00139 }
00140 if (args[1] != "1") {
00141 string message = string("Unsupported message server protocol version ") +
00142 args[1] + ".";
00143 throw IOException(message);
00144 }
00145
00146 authenticate(username, userSuppliedPassword);
00147 return this;
00148 } catch (const RuntimeException &) {
00149 autoDisconnect();
00150 throw;
00151 } catch (const SystemException &) {
00152 autoDisconnect();
00153 throw;
00154 } catch (const IOException &) {
00155 autoDisconnect();
00156 throw;
00157 } catch (const boost::thread_interrupted &) {
00158 autoDisconnect();
00159 throw;
00160 }
00161 }
00162
00163 void disconnect() {
00164 fd.close();
00165 fd = FileDescriptor();
00166 channel = MessageChannel();
00167 }
00168
00169 bool connected() const {
00170 return fd != -1;
00171 }
00172
00173 void setAutoDisconnect(bool value) {
00174 shouldAutoDisconnect = value;
00175 }
00176
00177 FileDescriptor getConnection() const {
00178 return fd;
00179 }
00180
00181
00182
00183
00184
00185 bool read(vector<string> &args) {
00186 try {
00187 return channel.read(args);
00188 } catch (const SystemException &e) {
00189 autoDisconnect();
00190 throw;
00191 }
00192 }
00193
00194
00195
00196
00197
00198
00199
00200 bool readScalar(string &output, unsigned int maxSize = 0, unsigned long long *timeout = NULL) {
00201 try {
00202 return channel.readScalar(output, maxSize, timeout);
00203 } catch (const SystemException &) {
00204 autoDisconnect();
00205 throw;
00206 } catch (const SecurityException &) {
00207 autoDisconnect();
00208 throw;
00209 } catch (const TimeoutException &) {
00210 autoDisconnect();
00211 throw;
00212 }
00213 }
00214
00215
00216
00217
00218
00219
00220 int readFileDescriptor(bool negotiate = true) {
00221 try {
00222 return channel.readFileDescriptor(negotiate);
00223 } catch (const SystemException &) {
00224 autoDisconnect();
00225 throw;
00226 } catch (const IOException &) {
00227 autoDisconnect();
00228 throw;
00229 }
00230 }
00231
00232
00233
00234
00235
00236 void write(const char *name, ...) {
00237 va_list ap;
00238 va_start(ap, name);
00239 try {
00240 try {
00241 channel.write(name, ap);
00242 } catch (const SystemException &) {
00243 autoDisconnect();
00244 throw;
00245 }
00246 va_end(ap);
00247 } catch (...) {
00248 va_end(ap);
00249 throw;
00250 }
00251 }
00252
00253
00254
00255
00256
00257 void writeScalar(const char *data, unsigned int size) {
00258 try {
00259 channel.writeScalar(data, size);
00260 } catch (const SystemException &) {
00261 autoDisconnect();
00262 throw;
00263 }
00264 }
00265
00266
00267
00268
00269
00270 void writeScalar(const StaticString &data) {
00271 try {
00272 channel.writeScalar(data.c_str(), data.size());
00273 } catch (const SystemException &) {
00274 autoDisconnect();
00275 throw;
00276 }
00277 }
00278
00279
00280
00281
00282
00283 void writeFileDescriptor(int fileDescriptor, bool negotiate = true) {
00284 try {
00285 channel.writeFileDescriptor(fileDescriptor, negotiate);
00286 } catch (const SystemException &) {
00287 autoDisconnect();
00288 throw;
00289 }
00290 }
00291 };
00292
00293 typedef shared_ptr<MessageClient> MessageClientPtr;
00294
00295 }
00296
00297 #endif