00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 #ifndef _PASSENGER_EVENTED_MESSAGE_SERVER_H_
00026 #define _PASSENGER_EVENTED_MESSAGE_SERVER_H_
00027
#include <boost/shared_ptr.hpp>
#include <ev++.h>
#include <cerrno>
#include <cstdarg>
#include <vector>
#include "EventedServer.h"
#include "MessageReadersWriters.h"
#include "AccountsDatabase.h"
#include "Constants.h"
#include "Utils.h"
00036
00037 namespace Passenger {
00038
00039 using namespace boost;
00040
00041
00042
00043 struct EventedMessageClientContext {
00044 enum State {
00045 MS_READING_USERNAME,
00046 MS_READING_PASSWORD,
00047 MS_READING_MESSAGE,
00048 MS_PROCESSING_MESSAGE
00049 };
00050
00051 State state;
00052 AccountPtr account;
00053
00054 ev::timer authenticationTimer;
00055 ScalarMessage scalarReader;
00056 ArrayMessage arrayReader;
00057 string username;
00058
00059 EventedMessageClientContext() {
00060 state = MS_READING_USERNAME;
00061 }
00062
00063 ~EventedMessageClientContext() {
00064
00065
00066 scalarReader.reset(true);
00067 }
00068 };
00069
00070 class EventedMessageClient: public EventedClient {
00071 public:
00072 EventedMessageClientContext messageServer;
00073
00074 EventedMessageClient(struct ev_loop *loop, const FileDescriptor &fd)
00075 : EventedClient(loop, fd)
00076 {
00077 messageServer.authenticationTimer.set(loop);
00078 }
00079
00080 void writeArrayMessage(const char *name, ...) {
00081 va_list ap;
00082 unsigned int count = 0;
00083
00084 va_start(ap, name);
00085 while (va_arg(ap, const char *) != NULL) {
00086 count++;
00087 }
00088 va_end(ap);
00089
00090 StaticString args[count + 1];
00091 unsigned int i = 1;
00092
00093 args[0] = name;
00094 va_start(ap, name);
00095 while (true) {
00096 const char *arg = va_arg(ap, const char *);
00097 if (arg != NULL) {
00098 args[i] = arg;
00099 i++;
00100 } else {
00101 break;
00102 }
00103 }
00104 va_end(ap);
00105
00106 writeArrayMessage(args, count + 1);
00107 }
00108
00109 void writeArrayMessage(StaticString args[], unsigned int count) {
00110 char headerBuf[sizeof(uint16_t)];
00111 unsigned int outSize = ArrayMessage::outputSize(count);
00112 StaticString out[outSize];
00113
00114 ArrayMessage::generate(args, count, headerBuf, out, outSize);
00115 write(out, outSize);
00116 }
00117 };
00118
00119
00120
00121
00122
00123
00124 class EventedMessageServer: public EventedServer {
00125 protected:
00126 AccountsDatabasePtr accountsDatabase;
00127
00128
00129
00130
00131 virtual EventedClient *createClient(const FileDescriptor &fd) {
00132 return new EventedMessageClient(getLoop(), fd);
00133 }
00134
00135 virtual void onNewClient(EventedClient *_client) {
00136 EventedMessageClient *client = (EventedMessageClient *) _client;
00137 EventedMessageClientContext *context = &client->messageServer;
00138
00139 context->authenticationTimer.set
00140 <&EventedMessageServer::onAuthenticationTimeout>(client);
00141 context->authenticationTimer.start(10);
00142
00143 context->arrayReader.reserve(5);
00144 context->scalarReader.setMaxSize(MESSAGE_SERVER_MAX_USERNAME_SIZE);
00145
00146 client->writeArrayMessage("version", protocolVersion(), NULL);
00147 }
00148
00149 virtual void onClientReadable(EventedClient *_client) {
00150 EventedMessageClient *client = (EventedMessageClient *) _client;
00151 this_thread::disable_syscall_interruption dsi;
00152 int i = 0;
00153 bool done = false;
00154
00155
00156
00157 while (i < 10 && !done) {
00158 char buf[1024 * 8];
00159 ssize_t ret;
00160
00161 ret = syscalls::read(client->fd, buf, sizeof(buf));
00162 if (ret == -1) {
00163 if (errno != EAGAIN) {
00164 int e = errno;
00165 client->disconnect(true);
00166 logSystemError(client, "Cannot read data from client", e);
00167 }
00168 done = true;
00169 } else if (ret == 0) {
00170 done = true;
00171 ScopeGuard guard(boost::bind(&EventedClient::disconnect,
00172 client, false));
00173 onEndOfStream(client);
00174 } else {
00175 onDataReceived(client, buf, ret);
00176 }
00177 i++;
00178 done = done || !client->ioAllowed();
00179 }
00180 }
00181
00182
00183
00184
00185 virtual void onClientAuthenticated(EventedMessageClient *client) {
00186
00187 }
00188
00189 virtual bool onMessageReceived(EventedMessageClient *client, const vector<StaticString> &args) {
00190 return true;
00191 }
00192
00193 virtual void onEndOfStream(EventedMessageClient *client) {
00194
00195 }
00196
00197 virtual pair<size_t, bool> onOtherDataReceived(EventedMessageClient *client,
00198 const char *data, size_t size)
00199 {
00200 abort();
00201 }
00202
00203 virtual const char *protocolVersion() const {
00204 return "1";
00205 }
00206
00207 void discardReadData() {
00208 readDataDiscarded = true;
00209 }
00210
00211 private:
00212 bool readDataDiscarded;
00213
00214 static void onAuthenticationTimeout(ev::timer &t, int revents) {
00215 EventedMessageClient *client = (EventedMessageClient *) t.data;
00216 client->disconnect();
00217 }
00218
00219 void onDataReceived(EventedMessageClient *client, char *data, size_t size) {
00220 EventedMessageClientContext *context = &client->messageServer;
00221 size_t consumed = 0;
00222
00223 readDataDiscarded = false;
00224 while (consumed < size && client->ioAllowed() && !readDataDiscarded) {
00225 char *current = data + consumed;
00226 size_t rest = size - consumed;
00227
00228 switch (context->state) {
00229 case EventedMessageClientContext::MS_READING_USERNAME:
00230 consumed += context->scalarReader.feed(current, rest);
00231 if (context->scalarReader.hasError()) {
00232 client->writeArrayMessage(
00233 "The supplied username is too long.",
00234 NULL);
00235 client->disconnect();
00236 } else if (context->scalarReader.done()) {
00237 context->username = context->scalarReader.value();
00238 context->scalarReader.reset();
00239 context->scalarReader.setMaxSize(MESSAGE_SERVER_MAX_PASSWORD_SIZE);
00240 context->state = EventedMessageClientContext::MS_READING_PASSWORD;
00241 }
00242 break;
00243
00244 case EventedMessageClientContext::MS_READING_PASSWORD: {
00245 size_t locallyConsumed;
00246
00247 locallyConsumed = context->scalarReader.feed(current, rest);
00248 consumed += locallyConsumed;
00249
00250
00251
00252 MemZeroGuard passwordGuard(current, locallyConsumed);
00253
00254 if (context->scalarReader.hasError()) {
00255 context->scalarReader.reset(true);
00256 client->writeArrayMessage(
00257 "The supplied password is too long.",
00258 NULL);
00259 client->disconnect();
00260 } else if (context->scalarReader.done()) {
00261 context->authenticationTimer.stop();
00262 context->account = accountsDatabase->authenticate(
00263 context->username, context->scalarReader.value());
00264 passwordGuard.zeroNow();
00265 context->username.clear();
00266 if (context->account) {
00267 context->scalarReader.reset(true);
00268 context->state = EventedMessageClientContext::MS_READING_MESSAGE;
00269 client->writeArrayMessage("ok", NULL);
00270 onClientAuthenticated(client);
00271 } else {
00272 context->scalarReader.reset(true);
00273 client->writeArrayMessage(
00274 "Invalid username or password.",
00275 NULL);
00276 client->disconnect();
00277 }
00278 }
00279 break;
00280 }
00281
00282 case EventedMessageClientContext::MS_READING_MESSAGE:
00283 consumed += context->arrayReader.feed(current, rest);
00284 if (context->arrayReader.hasError()) {
00285 client->disconnect();
00286 } else if (context->arrayReader.done()) {
00287 context->state = EventedMessageClientContext::MS_PROCESSING_MESSAGE;
00288 if (context->arrayReader.value().empty()) {
00289 logError(client, "Client sent an empty message.");
00290 client->disconnect();
00291 } else if (onMessageReceived(client, context->arrayReader.value())
00292 && context->state == EventedMessageClientContext::MS_PROCESSING_MESSAGE) {
00293 context->state = EventedMessageClientContext::MS_READING_MESSAGE;
00294 }
00295 context->arrayReader.reset();
00296 }
00297 break;
00298
00299 case EventedMessageClientContext::MS_PROCESSING_MESSAGE: {
00300 pair<size_t, bool> ret = onOtherDataReceived(client, current, rest);
00301 consumed += ret.first;
00302 if (ret.second && context->state == EventedMessageClientContext::MS_PROCESSING_MESSAGE) {
00303 context->state = EventedMessageClientContext::MS_READING_MESSAGE;
00304 }
00305 break;
00306 }
00307
00308 default:
00309
00310 abort();
00311 }
00312 }
00313 }
00314
00315 public:
00316 EventedMessageServer(struct ev_loop *loop, FileDescriptor fd,
00317 const AccountsDatabasePtr &accountsDatabase)
00318 : EventedServer(loop, fd)
00319 {
00320 this->accountsDatabase = accountsDatabase;
00321 }
00322 };
00323
00324
00325 }
00326
00327 #endif