00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00033
00034
00035
00036
00037
00038
#include <zmq.hpp>
#include "async-worker.h"
#include "OpenThreads/Thread"
#include <cassert>
#include <cstring>
#include <vector>
00043
00044 static zmq::context_t zmqContext(1) ;
00045 static const char zmqWorkSocketName[] = "inproc://async-work" ;
00046 static const char zmqDoneSocketName[] = "inproc://async-done" ;
00047
00049 namespace async
00050 {
00051
00057
00058 class WorkerThread : public OpenThreads::Thread
00059 {
00060 public:
00062 WorkerThread() {}
00063
00064 public:
00070 virtual void run()
00071 {
00072 zmq::socket_t inSocket(zmqContext, ZMQ_UPSTREAM) ;
00073 zmq::socket_t outSocket(zmqContext, ZMQ_DOWNSTREAM) ;
00074 inSocket.connect(zmqWorkSocketName) ;
00075 outSocket.connect(zmqDoneSocketName) ;
00076
00077 while ( true )
00078 {
00079 zmq::message_t msg ;
00080
00081 if ( !inSocket.recv(&msg, 0) )
00082 {
00083 return ;
00084 }
00085
00086 if ( msg.size() == 0 || msg.data() == 0 )
00087 {
00088 return ;
00089 }
00090
00091 async::workload_t* const instance = *(async::workload_t**)(msg.data()) ;
00092 if ( instance == NULL )
00093 continue ;
00094
00095
00096 const bool hasResults = instance->HasResults() ;
00097 instance->Work() ;
00098
00099
00100 if ( hasResults )
00101 {
00102 outSocket.send(msg, 0) ;
00103 }
00104 else
00105 instance->Destroy() ;
00106 }
00107 }
00108 } ;
00109
00113
00114 class WorkerPool
00115 {
00116 public:
00118 WorkerPool(size_t numThreads)
00119 : m_threads()
00120 , m_sendSocket(zmqContext, ZMQ_DOWNSTREAM)
00121 , m_recvSocket(zmqContext, ZMQ_UPSTREAM)
00122 {
00123 assert( numThreads > 0 ) ;
00124
00125 const size_t maxThreads = OpenThreads::GetNumberOfProcessors() * 2 ;
00126 if ( numThreads > maxThreads )
00127 numThreads = maxThreads ;
00128
00129 m_sendSocket.bind(zmqWorkSocketName) ;
00130 m_recvSocket.bind(zmqDoneSocketName) ;
00131
00132
00133 for ( size_t i = 0 ; i < numThreads ; ++i )
00134 {
00135 m_threads.push_back(new WorkerThread) ;
00136 m_threads.back()->start() ;
00137 }
00138 }
00139
00141 ~WorkerPool()
00142 {
00143
00144 for ( size_t i = 0 ; i < m_threads.size() ; ++i )
00145 {
00146 zmq::message_t emptyMsg ;
00147 m_sendSocket.send(emptyMsg, ZMQ_NOBLOCK) ;
00148 OpenThreads::Thread::YieldCurrentThread() ;
00149 }
00150
00151 for ( size_t i = 0 ; i < m_threads.size() ; ++i )
00152 {
00153 m_threads[i]->cancel() ;
00154 }
00155
00156 OpenThreads::Thread::YieldCurrentThread() ;
00157 }
00158
00159 private:
00160 std::vector<OpenThreads::Thread*> m_threads ;
00161
00162 private:
00163 zmq::socket_t m_sendSocket ;
00164 zmq::socket_t m_recvSocket ;
00165 size_t pendingResults ;
00166
00167 public:
00173 bool Send(const async::workload_t* const instancePtr)
00174 {
00175 const bool expectResults = instancePtr->HasResults() ;
00176
00177
00178
00179 zmq::message_t msg(sizeof(instancePtr)) ;
00180 *(const async::workload_t**)msg.data() = instancePtr ;
00181 if ( !m_sendSocket.send(msg, 0) )
00182 return false ;
00183
00184 if ( expectResults )
00185 ++pendingResults ;
00186
00187 return true ;
00188 }
00189
00192 void Recv()
00193 {
00194 zmq::message_t msg ;
00195 if ( !m_recvSocket.recv(&msg, 0) )
00196 return ;
00197
00198
00199 async::workload_t* const instancePtr = *(async::workload_t**)(msg.data()) ;
00200 if ( instancePtr != NULL )
00201 {
00202 instancePtr->Result() ;
00203 instancePtr->Destroy() ;
00204 }
00205 --pendingResults ;
00206 }
00207
00209 bool GetResults()
00210 {
00211 if ( pendingResults == 0 )
00212 return false ;
00213
00214 while ( pendingResults > 0 )
00215 {
00216 Recv() ;
00217 }
00218 return true ;
00219 }
00220
00222 size_t PendingResults() const
00223 {
00224 return pendingResults ;
00225 }
00226 } ;
00227
00229
00230
00231
00232
00233 static WorkerPool& getWorkerPool()
00234 {
00235 static WorkerPool workerPool(OpenThreads::GetNumberOfProcessors()) ;
00236 return workerPool ;
00237 }
00238
00240
00241
00242
00243 void Queue(const async::workload_t* const workload)
00244 {
00245 getWorkerPool().Send(workload) ;
00246 }
00247
00253
00254 bool GetResults()
00255 {
00256 return getWorkerPool().GetResults() ;
00257 }
00258
00261
00262 size_t PendingResults()
00263 {
00264 return getWorkerPool().PendingResults() ;
00265 }
00266
00267 }
00268