diff --git a/connection/connection.cpp b/connection/connection.cpp index fb03bf9..b7cd13a 100644 --- a/connection/connection.cpp +++ b/connection/connection.cpp @@ -1,901 +1,947 @@ /* Copyright (C) 2013 Andreas Hartmetz This library is free software; you can redistribute it and/or modify it under the terms of the GNU Library General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This library is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Library General Public License for more details. You should have received a copy of the GNU Library General Public License along with this library; see the file COPYING.LGPL. If not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. Alternatively, this file is available under the Mozilla Public License Version 1.1. You may obtain a copy of the License at http://www.mozilla.org/MPL/ */ #include "connection.h" #include "connection_p.h" #include "arguments.h" #include "authclient.h" #include "event.h" #include "eventdispatcher_p.h" #include "icompletionlistener.h" #include "iconnectionstatelistener.h" #include "imessagereceiver.h" #include "iserver.h" #include "localsocket.h" #include "message.h" #include "message_p.h" #include "pendingreply.h" #include "pendingreply_p.h" #include "stringtools.h" #include #include class HelloReceiver : public IMessageReceiver { public: void handlePendingReplyFinished(PendingReply *pr, Connection *) override { assert(pr == &m_helloReply); (void) pr; m_parent->handleHelloReply(); } PendingReply m_helloReply; // keep it here so it conveniently goes away when it's done ConnectionPrivate *m_parent; }; class ClientConnectedHandler : public ICompletionListener { public: ~ClientConnectedHandler() override { delete m_server; } void handleCompletion(void *) override { m_parent->handleClientConnected(); } IServer *m_server; ConnectionPrivate *m_parent; }; static Connection::State userState(ConnectionPrivate::State ps) { switch (ps) { case ConnectionPrivate::Unconnected: return Connection::Unconnected; case ConnectionPrivate::ServerWaitingForClient: case ConnectionPrivate::Authenticating: case ConnectionPrivate::AwaitingUniqueName: return Connection::Connecting; case ConnectionPrivate::Connected: return Connection::Connected; } assert(false); return Connection::Unconnected; } ConnectionStateChanger::ConnectionStateChanger(ConnectionPrivate *cp) : m_connPrivate(cp) { } ConnectionStateChanger::ConnectionStateChanger(ConnectionPrivate *cp, ConnectionPrivate::State newState) : m_connPrivate(cp), m_oldState(cp->m_state) { cp->m_state = newState; } ConnectionStateChanger::~ConnectionStateChanger() { if (m_oldState < 0) { return; } const Connection::State oldUserState = userState(static_cast(m_oldState)); const Connection::State newUserState = userState(m_connPrivate->m_state); if (oldUserState != newUserState) { m_connPrivate->notifyStateChange(oldUserState, newUserState); } } void ConnectionStateChanger::setNewState(ConnectionPrivate::State newState) { // Ensure that, in the destructor, the new state is always compared to the original old state if (m_oldState < 0) { m_oldState = m_connPrivate->m_state; } m_connPrivate->m_state = newState; } void ConnectionStateChanger::disable() { m_oldState = -1; } ConnectionPrivate::ConnectionPrivate(Connection *connection, EventDispatcher *dispatcher) : IIoEventForwarder(EventDispatcherPrivate::get(dispatcher)), m_connection(connection), m_eventDispatcher(dispatcher) { } IO::Status ConnectionPrivate::handleIoReady(IO::RW rw) { IO::Status status; IIoEventListener *const downstream = downstreamListener(); if (m_state == ServerWaitingForClient) { assert(downstream == m_clientConnectedHandler->m_server); } else { assert(downstream == m_transport); } if (downstream) { status = downstream->handleIoReady(rw); } else { status = IO::Status::InternalError; } ConnectionStateChanger stateChanger(this); if (status != IO::Status::OK) { - stateChanger.setNewState(ConnectionPrivate::Unconnected); - close(Error::RemoteDisconnect); + if (status != IO::Status::PayloadError) { + stateChanger.setNewState(ConnectionPrivate::Unconnected); + close(Error::RemoteDisconnect); + } else { + assert(!m_sendQueue.empty()); + const Message &msg = m_sendQueue.front(); + uint32 failedSerial = msg.serial(); + Error error = msg.error(); + m_sendQueue.pop_front(); + // If the following fails, there is no "spontaneously failed to send" notification mechanism. + // It is not a mistake in this case that it fails silently. + maybeDispatchToPendingReply(failedSerial, error); + } } return status; } Connection::Connection(EventDispatcher *dispatcher, const ConnectAddress &ca) : d(new ConnectionPrivate(this, dispatcher)) { d->m_connectAddress = ca; assert(d->m_eventDispatcher); EventDispatcherPrivate::get(d->m_eventDispatcher)->m_connectionToNotify = d; if (ca.type() == ConnectAddress::Type::None || ca.role() == ConnectAddress::Role::None) { return; } ConnectionStateChanger stateChanger(d); if (ca.role() == ConnectAddress::Role::PeerServer) { // this sets up a server that will be destroyed after accepting exactly one connection d->m_clientConnectedHandler = new ClientConnectedHandler; ConnectAddress dummyClientAddress; IServer *const is = IServer::create(ca, &dummyClientAddress); d->addIoListener(is); is->setNewConnectionListener(d->m_clientConnectedHandler); d->m_clientConnectedHandler->m_server = is; d->m_clientConnectedHandler->m_parent = d; stateChanger.setNewState(ConnectionPrivate::ServerWaitingForClient); } else { d->m_transport = ITransport::create(ca); d->addIoListener(d->m_transport); if (ca.role() == ConnectAddress::Role::BusClient) { d->startAuthentication(); stateChanger.setNewState(ConnectionPrivate::Authenticating); } else { assert(ca.role() == ConnectAddress::Role::PeerClient); // get ready to receive messages right away d->receiveNextMessage(); stateChanger.setNewState(ConnectionPrivate::Connected); } } } Connection::Connection(EventDispatcher *dispatcher, CommRef mainConnectionRef) : d(new ConnectionPrivate(this, dispatcher)) { EventDispatcherPrivate::get(d->m_eventDispatcher)->m_connectionToNotify = d; + // This must be destroyed after all the Lockers so we notify with no locks held! + ConnectionStateChanger stateChanger(d); + d->m_mainThreadLink = std::move(mainConnectionRef.commutex); CommutexLocker locker(&d->m_mainThreadLink); assert(locker.hasLock()); Commutex *const id = d->m_mainThreadLink.id(); if (!id) { assert(false); return; // stay in Unconnected state } - // TODO how do we handle m_state? - d->m_mainThreadConnection = mainConnectionRef.connection; ConnectionPrivate *mainD = d->m_mainThreadConnection; // get the current values - if we got them from e.g. the CommRef they could be outdated // and we don't want to wait for more event ping-pong SpinLocker mainLocker(&mainD->m_lock); d->m_connectAddress = mainD->m_connectAddress; // register with the main Connection SecondaryConnectionConnectEvent *evt = new SecondaryConnectionConnectEvent(); evt->connection = d; evt->id = id; EventDispatcherPrivate::get(mainD->m_eventDispatcher) ->queueEvent(std::unique_ptr(evt)); + stateChanger.setNewState(ConnectionPrivate::AwaitingUniqueName); } Connection::Connection(ITransport *transport, EventDispatcher *ed, const ConnectAddress &address) : d(new ConnectionPrivate(this, ed)) { // TODO FULLY validate address, also in the other constructors and in ITransport::create() // and in IServer::create()! assert(address.role() == ConnectAddress::Role::PeerServer); assert(d->m_eventDispatcher); d->m_transport = transport; d->addIoListener(d->m_transport); d->m_connectAddress = address; EventDispatcherPrivate::get(d->m_eventDispatcher)->m_connectionToNotify = d; #if 0 // TODO make the client authenticate itself, roughly along these lines // (not yet investigated whether peer auth is out of spec, optional or mandatory) // this sets up a server that will be destroyed after accepting exactly one connection d->m_clientConnectedHandler = new ClientConnectedHandler; d->m_clientConnectedHandler->m_server = IServer::create(ca); d->m_clientConnectedHandler->m_server->setEventDispatcher(dispatcher); d->m_clientConnectedHandler->m_server->setNewConnectionListener(d->m_clientConnectedHandler); d->m_clientConnectedHandler->m_parent = d; #endif d->receiveNextMessage(); ConnectionStateChanger stateChanger(d, ConnectionPrivate::Connected); } Connection::Connection(Connection &&other) { d = other.d; other.d = nullptr; if (d) { d->m_connection = this; } } Connection &Connection::operator=(Connection &&other) { this->~Connection(); d = other.d; other.d = nullptr; if (d) { d->m_connection = this; } return *this; } Connection::~Connection() { if (!d) { return; } d->close(Error::LocalDisconnect); delete d->m_transport; delete d->m_authClient; delete d->m_helloReceiver; delete d->m_receivingMessage; delete d; d = nullptr; } Connection::State Connection::state() const { return userState(d->m_state); } void Connection::close() { d->close(Error::LocalDisconnect); } void ConnectionPrivate::close(Error withError) { // Can't be main and secondary at the main time - it could be made to work, but what for? assert(m_secondaryThreadLinks.empty() || !m_mainThreadConnection); if (m_mainThreadConnection) { CommutexUnlinker unlinker(&m_mainThreadLink); if (unlinker.hasLock()) { SecondaryConnectionDisconnectEvent *evt = new SecondaryConnectionDisconnectEvent(); evt->connection = this; EventDispatcherPrivate::get(m_mainThreadConnection->m_eventDispatcher) ->queueEvent(std::unique_ptr(evt)); } } // Destroy whatever is suitable and available at a given time, in order to avoid things like // one secondary thread blocking another indefinitely and smaller dependency-related slowdowns. while (!m_secondaryThreadLinks.empty()) { for (auto it = m_secondaryThreadLinks.begin(); it != m_secondaryThreadLinks.end(); ) { CommutexUnlinker unlinker(&it->second, false); if (unlinker.willSucceed()) { if (unlinker.hasLock()) { MainConnectionDisconnectEvent *evt = new MainConnectionDisconnectEvent(); evt->error = withError; EventDispatcherPrivate::get(it->first->m_eventDispatcher) ->queueEvent(std::unique_ptr(evt)); } unlinker.unlinkNow(); // don't access the element after erasing it, finish it now it = m_secondaryThreadLinks.erase(it); } else { ++it; // don't block, try again next iteration } } } cancelAllPendingReplies(withError); EventDispatcherPrivate::get(m_eventDispatcher)->m_connectionToNotify = nullptr; if (m_transport) { m_transport->close(); } ConnectionStateChanger stateChanger(this, Unconnected); } void ConnectionPrivate::startAuthentication() { // Reserve serial 1 for the "hello" message - technically not necessary, there is no required ordering // of serials. takeNextSerial(); m_authClient = new AuthClient(m_transport); m_authClient->setCompletionListener(this); } void ConnectionPrivate::handleHelloReply() { ConnectionStateChanger stateChanger(this); if (!m_helloReceiver->m_helloReply.hasNonErrorReply()) { delete m_helloReceiver; m_helloReceiver = nullptr; stateChanger.setNewState(Unconnected); // TODO set an error, provide access to it, also set it on messages when trying to send / receive them return; } Arguments argList = m_helloReceiver->m_helloReply.reply()->arguments(); delete m_helloReceiver; m_helloReceiver = nullptr; Arguments::Reader reader(argList); assert(reader.state() == Arguments::String); cstring busName = reader.readString(); assert(reader.state() == Arguments::Finished); m_uniqueName = toStdString(busName); // tell current secondaries UniqueNameReceivedEvent evt; evt.uniqueName = m_uniqueName; for (auto &it : m_secondaryThreadLinks) { CommutexLocker otherLocker(&it.second); if (otherLocker.hasLock()) { EventDispatcherPrivate::get(it.first->m_eventDispatcher) ->queueEvent(std::unique_ptr(new UniqueNameReceivedEvent(evt))); } } stateChanger.setNewState(Connected); } void ConnectionPrivate::notifyStateChange(Connection::State oldUserState, Connection::State newUserState) { if (m_connectionStateListener) { m_connectionStateListener->handleConnectionChanged(m_connection, oldUserState, newUserState); } } void ConnectionPrivate::handleClientConnected() { m_transport = m_clientConnectedHandler->m_server->takeNextClient(); delete m_clientConnectedHandler; m_clientConnectedHandler = nullptr; assert(m_transport); addIoListener(m_transport); receiveNextMessage(); ConnectionStateChanger stateChanger(this, Connected); } void Connection::setDefaultReplyTimeout(int msecs) { d->m_defaultTimeout = msecs; } int Connection::defaultReplyTimeout() const { return d->m_defaultTimeout; } uint32 ConnectionPrivate::takeNextSerial() { uint32 ret; do { ret = m_sendSerial.fetch_add(1, std::memory_order_relaxed); } while (unlikely(ret == 0)); return ret; } Error ConnectionPrivate::prepareSend(Message *msg) { if (msg->serial() == 0) { if (!m_mainThreadConnection) { msg->setSerial(takeNextSerial()); } else { // we take a serial from the other Connection and then serialize locally in order to keep the CPU // expense of serialization local, even though it's more complicated than doing everything in the // other thread / Connection. CommutexLocker locker(&m_mainThreadLink); if (locker.hasLock()) { msg->setSerial(m_mainThreadConnection->takeNextSerial()); } else { return Error::LocalDisconnect; } } } MessagePrivate *const mpriv = MessagePrivate::get(msg); // this is unchanged by move()ing the owning Message. if (!mpriv->serialize()) { return mpriv->m_error; } return Error::NoError; } void ConnectionPrivate::sendPreparedMessage(Message msg) { MessagePrivate *const mpriv = MessagePrivate::get(&msg); mpriv->setCompletionListener(this); m_sendQueue.push_back(std::move(msg)); if (m_state == ConnectionPrivate::Connected && m_sendQueue.size() == 1) { // first in queue, don't wait for some other event to trigger sending mpriv->send(m_transport); } } PendingReply Connection::send(Message m, int timeoutMsecs) { if (timeoutMsecs == DefaultTimeout) { timeoutMsecs = d->m_defaultTimeout; } Error error = d->prepareSend(&m); PendingReplyPrivate *pendingPriv = new PendingReplyPrivate(d->m_eventDispatcher, timeoutMsecs); pendingPriv->m_connectionOrReply.connection = d; pendingPriv->m_receiver = nullptr; pendingPriv->m_serial = m.serial(); // even if we're handing off I/O to a main Connection, keep a record because that simplifies // aborting all pending replies when we disconnect from the main Connection, no matter which // side initiated the disconnection. d->m_pendingReplies.emplace(m.serial(), pendingPriv); - if (error.isError()) { + if (error.isError() || d->m_state == ConnectionPrivate::Unconnected) { // Signal the error asynchronously, in order to get the same delayed completion callback as in // the non-error case. This should make the behavior more predictable and client code harder to // accidentally get wrong. To detect errors immediately, PendingReply::error() can be used. - pendingPriv->m_error = error; + + // An intentionally locally disconnected connection is not in an error state, but trying to send + // a message over it is an error. + pendingPriv->m_error = error.isError() ? error : Error::LocalDisconnect; pendingPriv->m_replyTimeout.start(0); } else { if (!d->m_mainThreadConnection) { d->sendPreparedMessage(std::move(m)); } else { CommutexLocker locker(&d->m_mainThreadLink); if (locker.hasLock()) { std::unique_ptr evt(new SendMessageWithPendingReplyEvent); evt->message = std::move(m); evt->connection = d; EventDispatcherPrivate::get(d->m_mainThreadConnection->m_eventDispatcher) ->queueEvent(std::move(evt)); } else { pendingPriv->m_error = Error::LocalDisconnect; } } } return PendingReply(pendingPriv); } Error Connection::sendNoReply(Message m) { // ### (when not called from send()) warn if sending a message without the noreply flag set? // doing that is wasteful, but might be common. needs investigation. Error error = d->prepareSend(&m); - if (error.isError()) { - return error; + if (error.isError() || d->m_state == ConnectionPrivate::Unconnected) { + return error.isError() ? error : Error::LocalDisconnect; } // pass ownership to the send queue now because if the IO system decided to send the message without // going through an event loop iteration, handleCompletion would be called and expects the message to // be in the queue if (!d->m_mainThreadConnection) { d->sendPreparedMessage(std::move(m)); } else { CommutexLocker locker(&d->m_mainThreadLink); if (locker.hasLock()) { std::unique_ptr evt(new SendMessageEvent); evt->message = std::move(m); EventDispatcherPrivate::get(d->m_mainThreadConnection->m_eventDispatcher) ->queueEvent(std::move(evt)); } else { return Error::LocalDisconnect; } } return Error::NoError; } size_t Connection::sendQueueLength() const { return d->m_sendQueue.size(); } void Connection::waitForConnectionEstablished() { if (d->m_state != ConnectionPrivate::Authenticating) { return; } while (d->m_state == ConnectionPrivate::Authenticating) { d->m_authClient->handleTransportCanRead(); } if (d->m_state != ConnectionPrivate::AwaitingUniqueName) { return; } // Send the hello message assert(!d->m_sendQueue.empty()); // the hello message should be in the queue MessagePrivate *helloPriv = MessagePrivate::get(&d->m_sendQueue.front()); helloPriv->handleTransportCanWrite(); // Receive the hello reply while (d->m_state == ConnectionPrivate::AwaitingUniqueName) { MessagePrivate::get(d->m_receivingMessage)->handleTransportCanRead(); } } ConnectAddress Connection::connectAddress() const { return d->m_connectAddress; } std::string Connection::uniqueName() const { return d->m_uniqueName; } bool Connection::isConnected() const { return d->m_transport && d->m_transport->isOpen(); } EventDispatcher *Connection::eventDispatcher() const { return d->m_eventDispatcher; } IMessageReceiver *Connection::spontaneousMessageReceiver() const { return d->m_client; } void Connection::setSpontaneousMessageReceiver(IMessageReceiver *receiver) { d->m_client = receiver; } IConnectionStateListener *Connection::connectionStateListener() const { return d->m_connectionStateListener; } void Connection::setConnectionStateListener(IConnectionStateListener *listener) { d->m_connectionStateListener = listener; } void ConnectionPrivate::handleCompletion(void *task) { ConnectionStateChanger stateChanger(this); switch (m_state) { case Authenticating: { assert(task == m_authClient); if (!m_authClient->isAuthenticated()) { stateChanger.setNewState(Unconnected); } delete m_authClient; m_authClient = nullptr; if (m_state == Unconnected) { break; } stateChanger.setNewState(AwaitingUniqueName); // Announce our presence to the bus and have it send some introductory information of its own Message hello = Message::createCall("/org/freedesktop/DBus", "org.freedesktop.DBus", "Hello"); hello.setSerial(1); hello.setExpectsReply(false); hello.setDestination(std::string("org.freedesktop.DBus")); MessagePrivate *const helloPriv = MessagePrivate::get(&hello); m_helloReceiver = new HelloReceiver; m_helloReceiver->m_helloReply = m_connection->send(std::move(hello)); // Small hack: Connection::send() refuses to really start sending if the connection isn't in // Connected state. So force the sending here to actually get to Connected state. helloPriv->send(m_transport); // Also ensure that the hello message is sent before any other messages that may have been // already enqueued by an API client if (m_sendQueue.size() > 1) { hello = std::move(m_sendQueue.back()); m_sendQueue.pop_back(); m_sendQueue.push_front(std::move(hello)); } m_helloReceiver->m_helloReply.setReceiver(m_helloReceiver); m_helloReceiver->m_parent = this; // get ready to receive the first message, the hello reply receiveNextMessage(); break; } case AwaitingUniqueName: // the code paths for these two states only diverge in the PendingReply handler case Connected: { assert(!m_authClient); if (!m_sendQueue.empty() && task == &m_sendQueue.front()) { m_sendQueue.pop_front(); if (!m_sendQueue.empty()) { MessagePrivate::get(&m_sendQueue.front())->send(m_transport); } } else { assert(task == m_receivingMessage); Message *const receivedMessage = m_receivingMessage; receiveNextMessage(); if (receivedMessage->type() == Message::InvalidMessage) { delete receivedMessage; } else if (!maybeDispatchToPendingReply(receivedMessage)) { if (m_client) { m_client->handleSpontaneousMessageReceived(Message(std::move(*receivedMessage)), m_connection); } // dispatch to other threads listening to spontaneous messages, if any for (auto it = m_secondaryThreadLinks.begin(); it != m_secondaryThreadLinks.end(); ) { SpontaneousMessageReceivedEvent *evt = new SpontaneousMessageReceivedEvent(); if (std::next(it) != m_secondaryThreadLinks.end()) { evt->message = *receivedMessage; } else { evt->message = std::move(*receivedMessage); } CommutexLocker otherLocker(&it->second); if (otherLocker.hasLock()) { EventDispatcherPrivate::get(it->first->m_eventDispatcher) ->queueEvent(std::unique_ptr(evt)); ++it; } else { ConnectionPrivate *connection = it->first; it = m_secondaryThreadLinks.erase(it); discardPendingRepliesForSecondaryThread(connection); delete evt; } } delete receivedMessage; } } break; } default: // ### decide what to do here break; }; } bool ConnectionPrivate::maybeDispatchToPendingReply(Message *receivedMessage) { if (receivedMessage->type() != Message::MethodReturnMessage && receivedMessage->type() != Message::ErrorMessage) { return false; } auto it = m_pendingReplies.find(receivedMessage->replySerial()); if (it == m_pendingReplies.end()) { return false; } if (PendingReplyPrivate *pr = it->second.asPendingReply()) { m_pendingReplies.erase(it); assert(!pr->m_isFinished); pr->handleReceived(receivedMessage); } else { // forward to other thread's Connection ConnectionPrivate *connection = it->second.asConnection(); m_pendingReplies.erase(it); assert(connection); PendingReplySuccessEvent *evt = new PendingReplySuccessEvent; evt->reply = std::move(*receivedMessage); delete receivedMessage; EventDispatcherPrivate::get(connection->m_eventDispatcher)->queueEvent(std::unique_ptr(evt)); } return true; } +bool ConnectionPrivate::maybeDispatchToPendingReply(uint32 serial, Error error) +{ + assert(error.isError()); + auto it = m_pendingReplies.find(serial); + if (it == m_pendingReplies.end()) { + return false; + } + + if (PendingReplyPrivate *pr = it->second.asPendingReply()) { + m_pendingReplies.erase(it); + assert(!pr->m_isFinished); + pr->handleError(error); + } else { + // forward to other thread's Connection + ConnectionPrivate *connection = it->second.asConnection(); + m_pendingReplies.erase(it); + assert(connection); + PendingReplyFailureEvent *evt = new PendingReplyFailureEvent; + evt->m_serial = serial; + evt->m_error = error; + EventDispatcherPrivate::get(connection->m_eventDispatcher)->queueEvent(std::unique_ptr(evt)); + } + return true; +} + void ConnectionPrivate::receiveNextMessage() { m_receivingMessage = new Message; MessagePrivate *const mpriv = MessagePrivate::get(m_receivingMessage); mpriv->setCompletionListener(this); mpriv->receive(m_transport); } void ConnectionPrivate::unregisterPendingReply(PendingReplyPrivate *p) { if (m_mainThreadConnection) { CommutexLocker otherLocker(&m_mainThreadLink); if (otherLocker.hasLock()) { PendingReplyCancelEvent *evt = new PendingReplyCancelEvent; evt->serial = p->m_serial; EventDispatcherPrivate::get(m_mainThreadConnection->m_eventDispatcher) ->queueEvent(std::unique_ptr(evt)); } } #ifndef NDEBUG auto it = m_pendingReplies.find(p->m_serial); assert(it != m_pendingReplies.end()); if (!m_mainThreadConnection) { assert(it->second.asPendingReply()); assert(it->second.asPendingReply() == p); } #endif m_pendingReplies.erase(p->m_serial); } void ConnectionPrivate::cancelAllPendingReplies(Error withError) { // No locking because we should have no connections to other threads anymore at this point. // No const iteration followed by container clear because that has different semantics - many // things can happen in a callback... // In case we have pending replies for secondary threads, and we cancel all pending replies, // that is because we're shutting down, which we told the secondary thread, and it will deal // with bulk cancellation of replies. We just throw away our records about them. for (auto it = m_pendingReplies.begin() ; it != m_pendingReplies.end(); ) { PendingReplyPrivate *pendingPriv = it->second.asPendingReply(); it = m_pendingReplies.erase(it); if (pendingPriv) { // if from this thread pendingPriv->handleError(withError); } } + m_sendQueue.clear(); } void ConnectionPrivate::discardPendingRepliesForSecondaryThread(ConnectionPrivate *connection) { for (auto it = m_pendingReplies.begin() ; it != m_pendingReplies.end(); ) { if (it->second.asConnection() == connection) { it = m_pendingReplies.erase(it); // notification and deletion are handled on the event's source thread } else { ++it; } } } void ConnectionPrivate::processEvent(Event *evt) { // std::cerr << "ConnectionPrivate::processEvent() with event type " << evt->type << std::endl; switch (evt->type) { case Event::SendMessage: sendPreparedMessage(std::move(static_cast(evt)->message)); break; case Event::SendMessageWithPendingReply: { SendMessageWithPendingReplyEvent *pre = static_cast(evt); m_pendingReplies.emplace(pre->message.serial(), pre->connection); sendPreparedMessage(std::move(pre->message)); break; } case Event::SpontaneousMessageReceived: if (m_client) { SpontaneousMessageReceivedEvent *smre = static_cast(evt); m_client->handleSpontaneousMessageReceived(Message(std::move(smre->message)), m_connection); } break; case Event::PendingReplySuccess: maybeDispatchToPendingReply(&static_cast(evt)->reply); break; case Event::PendingReplyFailure: { PendingReplyFailureEvent *prfe = static_cast(evt); const auto it = m_pendingReplies.find(prfe->m_serial); if (it == m_pendingReplies.end()) { // not a disaster, but when it happens in debug mode I want to check it out assert(false); break; } PendingReplyPrivate *pendingPriv = it->second.asPendingReply(); m_pendingReplies.erase(it); pendingPriv->handleError(prfe->m_error); break; } case Event::PendingReplyCancel: // This comes from a secondary thread, which handles PendingReply notification itself. m_pendingReplies.erase(static_cast(evt)->serial); break; case Event::SecondaryConnectionConnect: { SecondaryConnectionConnectEvent *sce = static_cast(evt); const auto it = find_if(m_unredeemedCommRefs.begin(), m_unredeemedCommRefs.end(), [sce](const CommutexPeer &item) { return item.id() == sce->id; } ); assert(it != m_unredeemedCommRefs.end()); const auto emplaced = m_secondaryThreadLinks.emplace(sce->connection, std::move(*it)).first; m_unredeemedCommRefs.erase(it); // "welcome package" - it's done (only) as an event to avoid locking order issues CommutexLocker locker(&emplaced->second); if (locker.hasLock()) { UniqueNameReceivedEvent *evt = new UniqueNameReceivedEvent; evt->uniqueName = m_uniqueName; EventDispatcherPrivate::get(sce->connection->m_eventDispatcher) ->queueEvent(std::unique_ptr(evt)); } break; } case Event::SecondaryConnectionDisconnect: { SecondaryConnectionDisconnectEvent *sde = static_cast(evt); // delete our records to make sure we don't call into it in the future! const auto found = m_secondaryThreadLinks.find(sde->connection); if (found == m_secondaryThreadLinks.end()) { // looks like we've noticed the disappearance of the other thread earlier return; } m_secondaryThreadLinks.erase(found); discardPendingRepliesForSecondaryThread(sde->connection); break; } case Event::MainConnectionDisconnect: { // since the main thread *sent* us the event, it already knows to drop all our PendingReplies m_mainThreadConnection = nullptr; MainConnectionDisconnectEvent *mcde = static_cast(evt); cancelAllPendingReplies(mcde->error); break; } case Event::UniqueNameReceived: // We get this when the unique name became available after we were linked up with the main thread m_uniqueName = static_cast(evt)->uniqueName; + if (m_state == AwaitingUniqueName) { + ConnectionStateChanger stateChanger(this); + stateChanger.setNewState(Connected); + } break; } } Connection::CommRef Connection::createCommRef() { // TODO this is a good time to clean up "dead" CommRefs, where the counterpart was destroyed. CommRef ret; ret.connection = d; std::pair link = CommutexPeer::createLink(); { SpinLocker mainLocker(&d->m_lock); d->m_unredeemedCommRefs.emplace_back(std::move(link.first)); } ret.commutex = std::move(link.second); return ret; } -bool Connection::supportsPassingFileDescriptors() const +uint32 Connection::supportedFileDescriptorsPerMessage() const { - return d->m_transport && d->m_transport->supportsPassingFileDescriptors(); + return d->m_transport && d->m_transport->supportedPassingUnixFdsCount(); } diff --git a/connection/connection.h b/connection/connection.h index cd77ca0..ec58d6b 100644 --- a/connection/connection.h +++ b/connection/connection.h @@ -1,131 +1,131 @@ /* Copyright (C) 2013 Andreas Hartmetz This library is free software; you can redistribute it and/or modify it under the terms of the GNU Library General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This library is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Library General Public License for more details. You should have received a copy of the GNU Library General Public License along with this library; see the file COPYING.LGPL. If not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. Alternatively, this file is available under the Mozilla Public License Version 1.1. You may obtain a copy of the License at http://www.mozilla.org/MPL/ */ #ifndef CONNECTION_H #define CONNECTION_H #include "commutex.h" #include "types.h" #include class Connection; class ConnectAddress; class ConnectionPrivate; class Error; class EventDispatcher; class IConnectionStateListener; class IMessageReceiver; class ITransport; class Message; class PendingReply; class Server; class DFERRY_EXPORT Connection { public: enum State { Unconnected = 0, Connecting, Connected }; enum ThreadAffinity { MainConnection = 0, ThreadLocalConnection }; // Reference for passing to another thread; it guarantees that the target Connection // either exists or not, but is not currently being destroyed. Yes, the data is all private. class CommRef { friend class Connection; ConnectionPrivate *connection; CommutexPeer commutex; }; // for connecting to the session or system bus Connection(EventDispatcher *dispatcher, const ConnectAddress &connectAddress); // for reusing the connection of a Connection in another thread Connection(EventDispatcher *dispatcher, CommRef otherConnection); Connection(Connection &&other); Connection &operator=(Connection &&other); ~Connection(); Connection(Connection &other) = delete; Connection &operator=(Connection &other) = delete; State state() const; void close(); CommRef createCommRef(); - bool supportsPassingFileDescriptors() const; + uint32 supportedFileDescriptorsPerMessage() const; void setDefaultReplyTimeout(int msecs); int defaultReplyTimeout() const; enum TimeoutSpecialValues { DefaultTimeout = -1, NoTimeout = -2 }; // if a message expects no reply, that is not absolutely binding; this method allows to send a message that // does not expect (request) a reply, but we get it if it comes - not terribly useful in most cases // NOTE: this takes ownership of the message! The message will be deleted after sending in some future // event loop iteration, so it is guaranteed to stay valid before the next event loop iteration. PendingReply send(Message m, int timeoutMsecs = DefaultTimeout); // Mostly same as above. // This one ignores the reply, if any. Reports any locally detectable errors in the return value. Error sendNoReply(Message m); size_t sendQueueLength() const; void waitForConnectionEstablished(); ConnectAddress connectAddress() const; std::string uniqueName() const; bool isConnected() const; EventDispatcher *eventDispatcher() const; // TODO matching patterns for subscription; note that a signal requires path, interface and // "method" (signal name) of sender void subscribeToSignal(); IMessageReceiver *spontaneousMessageReceiver() const; void setSpontaneousMessageReceiver(IMessageReceiver *receiver); IConnectionStateListener *connectionStateListener() const; void setConnectionStateListener(IConnectionStateListener *listener); private: friend class Server; // called from Server Connection(ITransport *transport, EventDispatcher *eventDispatcher, const ConnectAddress &address); friend class ConnectionPrivate; ConnectionPrivate *d; }; #endif // CONNECTION_H diff --git a/connection/connection_p.h b/connection/connection_p.h index 4f47331..fd31422 100644 --- a/connection/connection_p.h +++ b/connection/connection_p.h @@ -1,207 +1,208 @@ /* Copyright (C) 2013 Andreas Hartmetz This library is free software; you can redistribute it and/or modify it under the terms of the GNU Library General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This library is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Library General Public License for more details. You should have received a copy of the GNU Library General Public License along with this library; see the file COPYING.LGPL. If not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. Alternatively, this file is available under the Mozilla Public License Version 1.1. You may obtain a copy of the License at http://www.mozilla.org/MPL/ */ #ifndef CONNECTION_P_H #define CONNECTION_P_H #include "connection.h" #include "connectaddress.h" #include "eventdispatcher_p.h" #include "icompletionlistener.h" #include "iioeventforwarder.h" #include "spinlock.h" #include #include #include class AuthClient; class HelloReceiver; class IMessageReceiver; class ITransport; class ClientConnectedHandler; /* How to handle destruction of connected Connections Main thread connection destroyed: Need to - "cancel" registered PendingReplies from other threads (I guess also own ones, we're not doing that, I think...) - Make sure that other threads stop calling us because that's going to be a memory error when our instance has been deleted Secondary thread Connection destroyed: Need to - "cancel" PendingReplies registered in main thread - unregister from main thread as receiver of spontaneous messages because receiving events about it is going to be a memory error when our instance has been deleted Problem areas: - destroying a Connection with a locked lock (locked from another thread, obviously) - can solved by "thoroughly" disconnecting from everything before destruction - deadlocks / locking order - preliminary solution: always main Connection first, then secondary - what about the lock in EventDispatcher? - blocking: secondary blocking (as in waiting for an event - both Connections wait on *locks* of the other) on main is okay, it does that all the time anyway. main blocking on secondary is probably (not sure) not okay. Let's define some invariants: - When a Connection is destroyed, all its PendingReply instances must have been detached (completed with or without error) or destroyed. "Its" means sent through that Connection's send() method, not when a PendingReply is using the connection of the Connection but send() was called on the Connection of another thread. - When a master and a secondary connection try to communicate in any way, and the other party has been destroyed, communication will fail gracefully and there will be no crash or undefined behavior. Any pending replies that cannot finish successfully anymore will finish with an LocalDisconnect error. */ class ConnectionStateChanger; // This class sits between EventDispatcher and ITransport for I/O event forwarding purposes, // which is why it is both a listener (for EventDispatcher) and a source (mainly for ITransport) class ConnectionPrivate : public IIoEventForwarder, public ICompletionListener { public: enum State { Unconnected = 0, ServerWaitingForClient, Authenticating, AwaitingUniqueName, Connected }; static ConnectionPrivate *get(Connection *c) { return c->d; } ConnectionPrivate(Connection *connection, EventDispatcher *dispatcher); void close(Error withError); // from IIOEventInterposer IO::Status handleIoReady(IO::RW rw) override; void startAuthentication(); void handleHelloReply(); // invokes m_connectionStateListener, if any void notifyStateChange(Connection::State oldUserState, Connection::State newUserState); void handleClientConnected(); uint32 takeNextSerial(); Error prepareSend(Message *msg); void sendPreparedMessage(Message msg); void handleCompletion(void *task) override; bool maybeDispatchToPendingReply(Message *m); + bool maybeDispatchToPendingReply(uint32 serial, Error error); void receiveNextMessage(); void unregisterPendingReply(PendingReplyPrivate *p); void cancelAllPendingReplies(Error withError); void discardPendingRepliesForSecondaryThread(ConnectionPrivate *t); // For cross-thread communication between thread Connections. We could have a more complete event // system, but there is currently no need, so keep it simple and limited. void processEvent(Event *evt); // called from thread-local EventDispatcher State m_state = Unconnected; Connection *m_connection = nullptr; IMessageReceiver *m_client = nullptr; IConnectionStateListener *m_connectionStateListener = nullptr; Message *m_receivingMessage = nullptr; std::deque m_sendQueue; // waiting to be sent // only one of them can be non-null. exception: in the main thread, m_mainThreadConnection // equals this, so that the main thread knows it's the main thread and not just a thread-local // connection. ITransport *m_transport = nullptr; HelloReceiver *m_helloReceiver = nullptr; ClientConnectedHandler *m_clientConnectedHandler = nullptr; EventDispatcher *m_eventDispatcher = nullptr; ConnectAddress m_connectAddress; std::string m_uniqueName; AuthClient *m_authClient = nullptr; int m_defaultTimeout = 25000; class PendingReplyRecord { public: PendingReplyRecord(PendingReplyPrivate *pr) : isForSecondaryThread(false), ptr(pr) {} PendingReplyRecord(ConnectionPrivate *tp) : isForSecondaryThread(true), ptr(tp) {} PendingReplyPrivate *asPendingReply() const { return isForSecondaryThread ? nullptr : static_cast(ptr); } ConnectionPrivate *asConnection() const { return isForSecondaryThread ? static_cast(ptr) : nullptr; } private: bool isForSecondaryThread; void *ptr; }; std::unordered_map m_pendingReplies; // replies we're waiting for Spinlock m_lock; // only one lock because things done with lock held are quick, and anyway you shouldn't // be using one connection from multiple threads if you need best performance std::atomic m_sendSerial { 1 }; std::unordered_map m_secondaryThreadLinks; std::vector m_unredeemedCommRefs; // for createCommRef() and the constructor from CommRef ConnectionPrivate *m_mainThreadConnection = nullptr; CommutexPeer m_mainThreadLink; }; // This class helps with notifying a Connection's StateChanegListener when connection state changes. // Its benefits are: // - Tracks state changes in a few easily verified pieces of code // - Prevents crashes from the following scenario: IConnectionStateListener is notified about a change. As a // reaction, it may delete the Connection. The listener returns and control passes back into Connection // code. Connection code touches some of its (or rather, ConnectionPrivate's) data, which has been deleted // at that point. Undefined behavior ensues. // With the help of this class, the IConnectionStateListener is always called just before exit, so that no // member data can be touched afterwards. (This pattern is a good idea for almost any kind of callback.) class ConnectionStateChanger { public: ConnectionStateChanger(ConnectionPrivate *cp); ConnectionStateChanger(ConnectionPrivate *cp, ConnectionPrivate::State newState); ConnectionStateChanger(const ConnectionStateChanger &) = delete; ConnectionStateChanger &operator=(const ConnectionStateChanger &) = delete; ~ConnectionStateChanger(); void setNewState(ConnectionPrivate::State newState); void disable(); private: ConnectionPrivate *m_connPrivate; int32 m_oldState = -1; // either -1 or a valid ConnectionPrivate::State }; #endif // CONNECTION_P_H diff --git a/serialization/arguments_p.h b/serialization/arguments_p.h index 44d1c3f..26faa1e 100644 --- a/serialization/arguments_p.h +++ b/serialization/arguments_p.h @@ -1,98 +1,100 @@ /* Copyright (C) 2013 Andreas Hartmetz This library is free software; you can redistribute it and/or modify it under the terms of the GNU Library General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This library is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Library General Public License for more details. You should have received a copy of the GNU Library General Public License along with this library; see the file COPYING.LGPL. If not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. Alternatively, this file is available under the Mozilla Public License Version 1.1. You may obtain a copy of the License at http://www.mozilla.org/MPL/ */ #ifndef ARGUMENTS_P_H #define ARGUMENTS_P_H #include "arguments.h" #include "error.h" class Arguments::Private { public: Private() : m_isByteSwapped(false), m_memOwnership(nullptr) {} + static inline Private *get(Arguments *args) { return args->d; } + Private(const Private &other); Private &operator=(const Private &other); void initFrom(const Private &other); ~Private(); chunk m_data; bool m_isByteSwapped; byte *m_memOwnership; cstring m_signature; std::vector m_fileDescriptors; Error m_error; }; struct TypeInfo { inline Arguments::IoState state() const { return static_cast(_state); } byte _state; byte alignment : 6; bool isPrimitive : 1; bool isString : 1; }; // helper to verify the max nesting requirements of the d-bus spec struct Nesting { inline Nesting() : array(0), paren(0), variant(0) {} static const int arrayMax = 32; static const int parenMax = 32; static const int totalMax = 64; inline bool beginArray() { array++; return likely(array <= arrayMax && total() <= totalMax); } inline void endArray() { assert(array >= 1); array--; } inline bool beginParen() { paren++; return likely(paren <= parenMax && total() <= totalMax); } inline void endParen() { assert(paren >= 1); paren--; } inline bool beginVariant() { variant++; return likely(total() <= totalMax); } inline void endVariant() { assert(variant >= 1); variant--; } inline uint32 total() { return array + paren + variant; } uint32 array; uint32 paren; uint32 variant; }; cstring printableState(Arguments::IoState state); bool parseSingleCompleteType(cstring *s, Nesting *nest); inline bool isAligned(uint32 value, uint32 alignment) { assert(alignment == 8 || alignment == 4 || alignment == 2 || alignment == 1); return (value & (alignment - 1)) == 0; } const TypeInfo &typeInfo(char letterCode); // Macros are icky, but here every use saves three lines. // Funny condition to avoid the dangling-else problem. #define VALID_IF(cond, errCode) if (likely(cond)) {} else { \ m_state = InvalidData; d->m_error.setCode(errCode); return; } #endif // ARGUMENTS_P_H diff --git a/serialization/argumentswriter.cpp b/serialization/argumentswriter.cpp index db765ff..4ed04d6 100644 --- a/serialization/argumentswriter.cpp +++ b/serialization/argumentswriter.cpp @@ -1,1250 +1,1251 @@ /* Copyright (C) 2013 Andreas Hartmetz This library is free software; you can redistribute it and/or modify it under the terms of the GNU Library General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This library is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Library General Public License for more details. You should have received a copy of the GNU Library General Public License along with this library; see the file COPYING.LGPL. If not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. Alternatively, this file is available under the Mozilla Public License Version 1.1. You may obtain a copy of the License at http://www.mozilla.org/MPL/ */ #include "arguments.h" #include "arguments_p.h" #include "basictypeio.h" #include "malloccache.h" #include #ifdef HAVE_BOOST #include #endif enum { StructAlignment = 8 }; static constexpr byte alignLog[9] = { 0, 0, 1, 0, 2, 0, 0, 0, 3 }; inline constexpr byte alignmentLog2(uint32 alignment) { // The following is not constexpr in C++14, and it hasn't triggered in ages // assert(alignment <= 8 && (alignment < 2 || alignLog[alignment] != 0)); return alignLog[alignment]; } class Arguments::Writer::Private { public: Private() : m_signaturePosition(0), m_data(reinterpret_cast(malloc(InitialDataCapacity))), m_dataCapacity(InitialDataCapacity), m_dataPosition(SignatureReservedSpace), m_nilArrayNesting(0) { m_signature.ptr = reinterpret_cast(m_data + 1); // reserve a byte for length prefix m_signature.length = 0; } Private(const Private &other); void operator=(const Private &other); void reserveData(uint32 size, IoState *state) { if (likely(size <= m_dataCapacity)) { return; } uint32 newCapacity = m_dataCapacity; do { newCapacity *= 2; } while (size > newCapacity); byte *const oldDataPointer = m_data; m_data = reinterpret_cast(realloc(m_data, newCapacity)); m_signature.ptr += m_data - oldDataPointer; m_dataCapacity = newCapacity; // Here, we trade off getting an ArgumentsTooLong error as early as possible for fewer // conditional branches in the hot path. Because only the final message length has a well-defined // limit, and Arguments doesn't, a precise check has limited usefulness anwyay. This is just // a sanity check and out of bounds access / overflow protection. // In most cases, callers do not need to check for errors: Very large single arguments are // already rejected, and we actually allocate the too large buffer to prevent out-of-bounds // access. Any following writing API calls will then cleanly abort due to m_state == InvalidData. // ### Callers DO need to check m_state before possibly overwriting it, hiding the error! if (newCapacity > Arguments::MaxMessageLength * 3) { *state = InvalidData; m_error.setCode(Error::ArgumentsTooLong); } } bool insideVariant() { return !m_queuedData.empty(); } // We don't know how long a variant signature is when starting the variant, but we have to // insert the signature into the datastream before the data. For that reason, we need a // postprocessing pass to fix things up once the outermost variant is closed. // QueuedDataInfo stores enough information about data inside variants to be able to do // the patching up while respecting alignment and other requirements. struct QueuedDataInfo { constexpr QueuedDataInfo(byte alignment, byte size_) : alignmentExponent(alignmentLog2(alignment)), size(size_) {} byte alignment() const { return 1 << alignmentExponent; } byte alignmentExponent : 2; // powers of 2, so 1, 2, 4, 8 byte size : 6; // that's up to 63 enum SizeCode { LargestSize = 60, ArrayLengthField, ArrayLengthEndMark, VariantSignature }; }; // The parameter is not a QueuedDataInfo because the compiler doesn't seem to optimize away // QueuedDataInfo construction when insideVariant() is false, despite inlining. void maybeQueueData(byte alignment, byte size) { if (insideVariant()) { m_queuedData.push_back(QueuedDataInfo(alignment, size)); } } // Caution: does not ensure that enough space is available! void appendBulkData(chunk data) { // Align only the first of the back-to-back data chunks - otherwise, when storing values which // are 8 byte aligned, the second half of an element straddling a chunk boundary // (QueuedDataInfo::LargestSize == 60) would start at an 8-byte aligned position (so 64) // instead of 60 where we want it in order to just write a contiguous block of data. memcpy(m_data + m_dataPosition, data.ptr, data.length); m_dataPosition += data.length; if (insideVariant()) { for (uint32 l = data.length; l; ) { uint32 chunkSize = std::min(l, uint32(QueuedDataInfo::LargestSize)); m_queuedData.push_back(QueuedDataInfo(1, chunkSize)); l -= chunkSize; } } } void alignData(uint32 alignment) { if (insideVariant()) { m_queuedData.push_back(QueuedDataInfo(alignment, 0)); } zeroPad(m_data, alignment, &m_dataPosition); } uint32 m_dataElementsCountBeforeNilArray; uint32 m_dataPositionBeforeVariant; Nesting m_nesting; cstring m_signature; uint32 m_signaturePosition; byte *m_data; uint32 m_dataCapacity; uint32 m_dataPosition; int m_nilArrayNesting; std::vector m_fileDescriptors; Error m_error; enum { InitialDataCapacity = 512, // max signature length (255) + length prefix(1) + null terminator(1), rounded up to multiple of 8 // because that doesn't change alignment SignatureReservedSpace = 264 }; #ifdef WITH_DICT_ENTRY enum DictEntryState : byte { RequireBeginDictEntry = 0, InDictEntry, RequireEndDictEntry, AfterEndDictEntry }; #endif struct ArrayInfo { uint32 containedTypeBegin; // to rewind when reading the next element #ifdef WITH_DICT_ENTRY DictEntryState dictEntryState; uint32 lengthFieldPosition : 24; #else uint32 lengthFieldPosition; #endif }; struct VariantInfo { // a variant switches the currently parsed signature, so we // need to store the old signature and parse position. uint32 prevSignatureOffset; // relative to m_data uint32 prevSignaturePosition; }; struct StructInfo { uint32 containedTypeBegin; }; struct AggregateInfo { IoState aggregateType; // can be BeginArray, BeginDict, BeginStruct, BeginVariant union { ArrayInfo arr; VariantInfo var; StructInfo sct; }; }; // this keeps track of which aggregates we are currently in #ifdef HAVE_BOOST boost::container::small_vector m_aggregateStack; #else std::vector m_aggregateStack; #endif std::vector m_queuedData; }; thread_local static MallocCache allocCache; Arguments::Writer::Private::Private(const Private &other) { *this = other; } void Arguments::Writer::Private::operator=(const Private &other) { if (&other == this) { assert(false); // if this happens, the (internal) caller did something wrong return; } m_dataElementsCountBeforeNilArray = other.m_dataElementsCountBeforeNilArray; m_dataPositionBeforeVariant = other.m_dataPositionBeforeVariant; m_nesting = other.m_nesting; m_signature.ptr = other.m_signature.ptr; // ### still needs adjustment, done after allocating m_data m_signature.length = other.m_signature.length; m_signaturePosition = other.m_signaturePosition; m_dataCapacity = other.m_dataCapacity; m_dataPosition = other.m_dataPosition; // handle *m_data and the data it's pointing to m_data = reinterpret_cast(malloc(m_dataCapacity)); memcpy(m_data, other.m_data, m_dataPosition); m_signature.ptr += m_data - other.m_data; m_nilArrayNesting = other.m_nilArrayNesting; m_fileDescriptors = other.m_fileDescriptors; m_error = other.m_error; m_aggregateStack = other.m_aggregateStack; m_queuedData = other.m_queuedData; } Arguments::Writer::Writer() : d(new(allocCache.allocate()) Private), m_state(AnyData) { } Arguments::Writer::Writer(Writer &&other) : d(other.d), m_state(other.m_state), m_u(other.m_u) { other.d = nullptr; } void Arguments::Writer::operator=(Writer &&other) { if (&other == this) { return; } d = other.d; m_state = other.m_state; m_u = other.m_u; other.d = nullptr; } Arguments::Writer::Writer(const Writer &other) : d(nullptr), m_state(other.m_state), m_u(other.m_u) { if (other.d) { d = new(allocCache.allocate()) Private(*other.d); } } void Arguments::Writer::operator=(const Writer &other) { if (&other == this) { return; } m_state = other.m_state; m_u = other.m_u; if (d && other.d) { *d = *other.d; } else { Writer temp(other); std::swap(d, temp.d); } } Arguments::Writer::~Writer() { if (d) { free(d->m_data); d->m_data = nullptr; d->~Private(); allocCache.free(d); d = nullptr; } } bool Arguments::Writer::isValid() const { return !d->m_error.isError(); } Error Arguments::Writer::error() const { return d->m_error; } cstring Arguments::Writer::stateString() const { return printableState(m_state); } bool Arguments::Writer::isInsideEmptyArray() const { return d->m_nilArrayNesting > 0; } cstring Arguments::Writer::currentSignature() const { // A signature must be null-terminated to be valid. // We're only overwriting uninitialized memory, no need to undo that later. d->m_signature.ptr[d->m_signature.length] = '\0'; return d->m_signature; } uint32 Arguments::Writer::currentSignaturePosition() const { return d->m_signaturePosition; } void Arguments::Writer::doWritePrimitiveType(IoState type, uint32 alignAndSize) { d->reserveData(d->m_dataPosition + (alignAndSize << 1), &m_state); zeroPad(d->m_data, alignAndSize, &d->m_dataPosition); switch(type) { case Boolean: { uint32 num = m_u.Boolean ? 1 : 0; basic::writeUint32(d->m_data + d->m_dataPosition, num); break; } case Byte: d->m_data[d->m_dataPosition] = m_u.Byte; break; case Int16: basic::writeInt16(d->m_data + d->m_dataPosition, m_u.Int16); break; case Uint16: basic::writeUint16(d->m_data + d->m_dataPosition, m_u.Uint16); break; case Int32: basic::writeInt32(d->m_data + d->m_dataPosition, m_u.Int32); break; case Uint32: basic::writeUint32(d->m_data + d->m_dataPosition, m_u.Uint32); break; case Int64: basic::writeInt64(d->m_data + d->m_dataPosition, m_u.Int64); break; case Uint64: basic::writeUint64(d->m_data + d->m_dataPosition, m_u.Uint64); break; case Double: basic::writeDouble(d->m_data + d->m_dataPosition, m_u.Double); break; case UnixFd: { const uint32 index = d->m_fileDescriptors.size(); if (!d->m_nilArrayNesting) { d->m_fileDescriptors.push_back(m_u.Int32); } basic::writeUint32(d->m_data + d->m_dataPosition, index); break; } default: assert(false); VALID_IF(false, Error::InvalidType); } d->m_dataPosition += alignAndSize; d->maybeQueueData(alignAndSize, alignAndSize); } void Arguments::Writer::doWriteString(IoState type, uint32 lengthPrefixSize) { if (type == String) { VALID_IF(Arguments::isStringValid(cstring(m_u.String.ptr, m_u.String.length)), Error::InvalidString); } else if (type == ObjectPath) { VALID_IF(Arguments::isObjectPathValid(cstring(m_u.String.ptr, m_u.String.length)), Error::InvalidObjectPath); } else if (type == Signature) { VALID_IF(Arguments::isSignatureValid(cstring(m_u.String.ptr, m_u.String.length)), Error::InvalidSignature); } d->reserveData(d->m_dataPosition + (lengthPrefixSize << 1) + m_u.String.length + 1, &m_state); zeroPad(d->m_data, lengthPrefixSize, &d->m_dataPosition); if (lengthPrefixSize == 1) { d->m_data[d->m_dataPosition] = m_u.String.length; } else { basic::writeUint32(d->m_data + d->m_dataPosition, m_u.String.length); } d->m_dataPosition += lengthPrefixSize; d->maybeQueueData(lengthPrefixSize, lengthPrefixSize); d->appendBulkData(chunk(m_u.String.ptr, m_u.String.length + 1)); } void Arguments::Writer::advanceState(cstring signatureFragment, IoState newState) { // what needs to happen here: // - if we are in an existing portion of the signature (like writing the >1st iteration of an array) // check if the type to be written is the same as the one that's already in the signature // - otherwise we still need to check if the data we're adding conforms with the spec, e.g. // no empty structs, dict entries must have primitive key type and exactly one value type // - check well-formedness of data: strings, maximum serialized array length and message length // (variant signature length only being known after finishing a variant introduces uncertainty // of final data stream size - due to alignment padding, a variant signature longer by one can // cause an up to seven bytes longer message. in other cases it won't change message length at all.) // - increase size of data buffer when it gets too small // - store information about variants and arrays, in order to: // - know what the final binary message size will be // - in finish(), create the final data stream with inline variant signatures and array lengths if (unlikely(m_state == InvalidData)) { return; } // can't do the following because a dict is one aggregate in our counting, but two according to // the spec: an array (one) containing dict entries (two) // assert(d->m_nesting.total() == d->m_aggregateStack.size()); assert((d->m_nesting.total() == 0) == d->m_aggregateStack.empty()); m_state = AnyData; uint32 alignment = 1; bool isPrimitiveType = false; bool isStringType = false; if (signatureFragment.length) { const TypeInfo ty = typeInfo(signatureFragment.ptr[0]); alignment = ty.alignment; isPrimitiveType = ty.isPrimitive; isStringType = ty.isString; } bool isWritingSignature = d->m_signaturePosition == d->m_signature.length; if (isWritingSignature) { // signature additions must conform to syntax VALID_IF(d->m_signaturePosition + signatureFragment.length <= MaxSignatureLength, Error::SignatureTooLong); } if (!d->m_aggregateStack.empty()) { Private::AggregateInfo &aggregateInfo = d->m_aggregateStack.back(); switch (aggregateInfo.aggregateType) { case BeginVariant: // arrays and variants may contain just one single complete type; note that this will // trigger only when not inside an aggregate inside the variant or (see below) array if (d->m_signaturePosition >= 1) { VALID_IF(newState == EndVariant, Error::NotSingleCompleteTypeInVariant); } break; case BeginArray: if (d->m_signaturePosition >= aggregateInfo.arr.containedTypeBegin + 1 && newState != EndArray) { // we are not at start of contained type's signature, the array is at top of stack // -> we are at the end of the single complete type inside the array, start the next // entry. TODO: check compatibility (essentially what's in the else branch below) d->m_signaturePosition = aggregateInfo.arr.containedTypeBegin; isWritingSignature = false; } break; case BeginDict: if (d->m_signaturePosition == aggregateInfo.arr.containedTypeBegin) { #ifdef WITH_DICT_ENTRY if (aggregateInfo.arr.dictEntryState == Private::RequireBeginDictEntry) { // This is only reached immediately after beginDict() so it's kinda wasteful, oh well. VALID_IF(newState == BeginDictEntry, Error::MissingBeginDictEntry); aggregateInfo.arr.dictEntryState = Private::InDictEntry; m_state = DictKey; return; // BeginDictEntry writes no data } #endif VALID_IF(isPrimitiveType || isStringType, Error::InvalidKeyTypeInDict); } #ifdef WITH_DICT_ENTRY // TODO test this part of the state machine if (d->m_signaturePosition >= aggregateInfo.arr.containedTypeBegin + 2) { if (aggregateInfo.arr.dictEntryState == Private::RequireEndDictEntry) { VALID_IF(newState == EndDictEntry, Error::MissingEndDictEntry); aggregateInfo.arr.dictEntryState = Private::AfterEndDictEntry; m_state = BeginDictEntry; return; // EndDictEntry writes no data } else { // v should've been caught earlier assert(aggregateInfo.arr.dictEntryState == Private::AfterEndDictEntry); VALID_IF(newState == BeginDictEntry || newState == EndDict, Error::MissingBeginDictEntry); // "fall through", the rest (another iteration or finish) is handled below } } else if (d->m_signaturePosition >= aggregateInfo.arr.containedTypeBegin + 1) { assert(aggregateInfo.arr.dictEntryState == Private::InDictEntry); aggregateInfo.arr.dictEntryState = Private::RequireEndDictEntry; // Setting EndDictEntry after writing a primitive type works fine, but setting it after // ending another aggregate would be somewhat involved and need to happen somewhere // else, so just don't do that. We still produce an error when endDictEntry() is not // used correctly. // m_state = EndDictEntry; // continue and write the dict entry's value } #endif // first type has been checked already, second must be present (checked in EndDict // state handler). no third type allowed. if (d->m_signaturePosition >= aggregateInfo.arr.containedTypeBegin + 2 && newState != EndDict) { // align to dict entry d->alignData(StructAlignment); d->m_signaturePosition = aggregateInfo.arr.containedTypeBegin; isWritingSignature = false; m_state = DictKey; #ifdef WITH_DICT_ENTRY assert(newState == BeginDictEntry); aggregateInfo.arr.dictEntryState = Private::InDictEntry; return; // BeginDictEntry writes no data #endif } break; default: break; } } if (isWritingSignature) { // extend the signature for (uint32 i = 0; i < signatureFragment.length; i++) { d->m_signature.ptr[d->m_signaturePosition++] = signatureFragment.ptr[i]; } d->m_signature.length += signatureFragment.length; } else { // Do not try to prevent several iterations through a nil array. Two reasons: // - We may be writing a nil array in the >1st iteration of a non-nil outer array. // This would need to be distinguished from just iterating through a nil array // several times. Which is well possible. We don't bother with that because... // - As a QtDBus unittest illustrates, somebody may choose to serialize a fixed length // series of data elements as an array (instead of struct), so that a trivial // serialization of such data just to fill in type information in an outer empty array // would end up iterating through the inner, implicitly empty array several times. // All in all it is just not much of a benefit to be strict, so don't. //VALID_IF(likely(!d->m_nilArrayNesting), Error::ExtraIterationInEmptyArray); // signature must match first iteration (of an array/dict) VALID_IF(d->m_signaturePosition + signatureFragment.length <= d->m_signature.length, Error::TypeMismatchInSubsequentArrayIteration); // TODO need to apply special checks for state changes with no explicit signature char? // (end of array, end of variant) for (uint32 i = 0; i < signatureFragment.length; i++) { VALID_IF(d->m_signature.ptr[d->m_signaturePosition++] == signatureFragment.ptr[i], Error::TypeMismatchInSubsequentArrayIteration); } } if (isPrimitiveType) { doWritePrimitiveType(newState, alignment); return; } if (isStringType) { // In case of nil array, skip writing to make sure that the input string (which is explicitly // allowed to be garbage) is not validated and no wild pointer is dereferenced. if (likely(!d->m_nilArrayNesting)) { doWriteString(newState, alignment); } else { // The alignment of the first element in a nil array determines where array data starts, // which is needed to serialize the length correctly. Write the minimum to achieve that. // (The check to see if we're really at the first element is omitted - for performance // it's worth trying to add that check) d->alignData(alignment); } return; } Private::AggregateInfo aggregateInfo; switch (newState) { case BeginStruct: VALID_IF(d->m_nesting.beginParen(), Error::ExcessiveNesting); aggregateInfo.aggregateType = BeginStruct; aggregateInfo.sct.containedTypeBegin = d->m_signaturePosition; d->m_aggregateStack.push_back(aggregateInfo); d->alignData(alignment); break; case EndStruct: VALID_IF(!d->m_aggregateStack.empty(), Error::CannotEndStructHere); aggregateInfo = d->m_aggregateStack.back(); VALID_IF(aggregateInfo.aggregateType == BeginStruct && d->m_signaturePosition > aggregateInfo.sct.containedTypeBegin + 1, Error::EmptyStruct); // empty structs are not allowed d->m_nesting.endParen(); d->m_aggregateStack.pop_back(); break; case BeginVariant: { VALID_IF(d->m_nesting.beginVariant(), Error::ExcessiveNesting); aggregateInfo.aggregateType = BeginVariant; Private::VariantInfo &variantInfo = aggregateInfo.var; variantInfo.prevSignatureOffset = uint32(reinterpret_cast(d->m_signature.ptr) - d->m_data); d->m_signature.ptr[-1] = byte(d->m_signature.length); variantInfo.prevSignaturePosition = d->m_signaturePosition; if (!d->insideVariant()) { d->m_dataPositionBeforeVariant = d->m_dataPosition; } d->m_aggregateStack.push_back(aggregateInfo); d->m_queuedData.reserve(16); d->m_queuedData.push_back(Private::QueuedDataInfo(1, Private::QueuedDataInfo::VariantSignature)); const uint32 newDataPosition = d->m_dataPosition + Private::SignatureReservedSpace; d->reserveData(newDataPosition, &m_state); // allocate new signature in the data buffer, reserve one byte for length prefix d->m_signature.ptr = reinterpret_cast(d->m_data) + d->m_dataPosition + 1; d->m_signature.length = 0; d->m_signaturePosition = 0; d->m_dataPosition = newDataPosition; break; } case EndVariant: { VALID_IF(!d->m_aggregateStack.empty(), Error::CannotEndVariantHere); aggregateInfo = d->m_aggregateStack.back(); VALID_IF(aggregateInfo.aggregateType == BeginVariant, Error::CannotEndVariantHere); d->m_nesting.endVariant(); if (likely(!d->m_nilArrayNesting)) { // Empty variants are not allowed. As an exception, in nil arrays they are // allowed for writing a type signature like "av" in the shortest possible way. // No use adding stuff when it's not required or even possible. VALID_IF(d->m_signaturePosition > 0, Error::EmptyVariant); assert(d->m_signaturePosition <= MaxSignatureLength); // should have been caught earlier } d->m_signature.ptr[-1] = byte(d->m_signaturePosition); Private::VariantInfo &variantInfo = aggregateInfo.var; d->m_signature.ptr = reinterpret_cast(d->m_data) + variantInfo.prevSignatureOffset; d->m_signature.length = d->m_signature.ptr[-1]; d->m_signaturePosition = variantInfo.prevSignaturePosition; d->m_aggregateStack.pop_back(); // if not in any variant anymore, flush queued data and resume unqueued operation if (d->m_signature.ptr == reinterpret_cast(d->m_data) + 1) { flushQueuedData(); } break; } case BeginDict: case BeginArray: { VALID_IF(d->m_nesting.beginArray(), Error::ExcessiveNesting); if (newState == BeginDict) { // not re-opened before each element: there is no observable difference for clients VALID_IF(d->m_nesting.beginParen(), Error::ExcessiveNesting); } aggregateInfo.aggregateType = newState; aggregateInfo.arr.containedTypeBegin = d->m_signaturePosition; d->reserveData(d->m_dataPosition + (sizeof(uint32) << 1), &m_state); if (m_state == InvalidData) { break; // should be excessive length error from reserveData - do not unset error state } zeroPad(d->m_data, sizeof(uint32), &d->m_dataPosition); basic::writeUint32(d->m_data + d->m_dataPosition, 0); aggregateInfo.arr.lengthFieldPosition = d->m_dataPosition; d->m_dataPosition += sizeof(uint32); d->maybeQueueData(sizeof(uint32), Private::QueuedDataInfo::ArrayLengthField); if (newState == BeginDict) { d->alignData(StructAlignment); #ifdef WITH_DICT_ENTRY m_state = BeginDictEntry; aggregateInfo.arr.dictEntryState = Private::RequireBeginDictEntry; #else m_state = DictKey; #endif } d->m_aggregateStack.push_back(aggregateInfo); break; } case EndDict: case EndArray: { const bool isDict = newState == EndDict; VALID_IF(!d->m_aggregateStack.empty(), Error::CannotEndArrayHere); aggregateInfo = d->m_aggregateStack.back(); VALID_IF(aggregateInfo.aggregateType == (isDict ? BeginDict : BeginArray), Error::CannotEndArrayOrDictHere); VALID_IF(d->m_signaturePosition >= aggregateInfo.arr.containedTypeBegin + (isDict ? 3 : 1), Error::TooFewTypesInArrayOrDict); if (isDict) { d->m_nesting.endParen(); } d->m_nesting.endArray(); // array data starts (and in empty arrays ends) at the first array element position *after alignment* const uint32 contentAlign = isDict ? 8 : typeInfo(d->m_signature.ptr[aggregateInfo.arr.containedTypeBegin]).alignment; const uint32 arrayDataStart = align(aggregateInfo.arr.lengthFieldPosition + sizeof(uint32), contentAlign); if (unlikely(d->m_nilArrayNesting)) { if (--d->m_nilArrayNesting == 0) { d->m_dataPosition = arrayDataStart; if (d->insideVariant()) { assert(d->m_queuedData.begin() + d->m_dataElementsCountBeforeNilArray <= d->m_queuedData.end()); d->m_queuedData.erase(d->m_queuedData.begin() + d->m_dataElementsCountBeforeNilArray, d->m_queuedData.end()); assert((d->m_queuedData.end() - 2)->size == Private::QueuedDataInfo::ArrayLengthField); // align, but don't have actual data for the first element d->m_queuedData.back().size = 0; } } } // (arrange to) patch in the array length now that it is known if (d->insideVariant()) { d->m_queuedData.push_back(Private::QueuedDataInfo(1, Private::QueuedDataInfo::ArrayLengthEndMark)); } else { const uint32 arrayLength = d->m_dataPosition - arrayDataStart; VALID_IF(arrayLength <= Arguments::MaxArrayLength, Error::ArrayOrDictTooLong); basic::writeUint32(d->m_data + aggregateInfo.arr.lengthFieldPosition, arrayLength); } d->m_aggregateStack.pop_back(); break; } #ifdef WITH_DICT_ENTRY case BeginDictEntry: case EndDictEntry: break; #endif default: VALID_IF(false, Error::InvalidType); break; } } void Arguments::Writer::beginArrayOrDict(IoState beginWhat, ArrayOption option) { assert(beginWhat == BeginArray || beginWhat == BeginDict); if (unlikely(option == RestartEmptyArrayToWriteTypes)) { if (!d->m_aggregateStack.empty()) { Private::AggregateInfo &aggregateInfo = d->m_aggregateStack.back(); if (aggregateInfo.aggregateType == beginWhat) { // No writes to the array or dict may have occurred yet if (d->m_signaturePosition == aggregateInfo.arr.containedTypeBegin) { // Fix up state as if beginArray/Dict() had been called with WriteTypesOfEmptyArray // in the first place. After that small fixup we're done and return. // The code is a slightly modified version of code below under: if (isEmpty) { if (!d->m_nilArrayNesting) { d->m_nilArrayNesting = 1; d->m_dataElementsCountBeforeNilArray = d->m_queuedData.size() + 2; // +2 as below // Now correct for the elements already added in advanceState() with BeginArray / BeginDict d->m_dataElementsCountBeforeNilArray -= (beginWhat == BeginDict) ? 2 : 1; } else { // The array may be implicitly nil (so our poor API client doesn't notice) because // an array below in the aggregate stack is nil, so just allow this as a no-op. } return; } } } VALID_IF(false, Error::InvalidStateToRestartEmptyArray); } const bool isEmpty = (option != NonEmptyArray) || d->m_nilArrayNesting; if (isEmpty) { if (!d->m_nilArrayNesting++) { // For simplictiy and performance in the fast path, we keep storing the data chunks and any // variant signatures written inside an empty array. When we close the array, though, we // throw away all that data and signatures and keep only changes in the signature containing // the topmost empty array. // +2 -> keep ArrayLengthField, and first data element for alignment purposes d->m_dataElementsCountBeforeNilArray = d->m_queuedData.size() + 2; } } if (beginWhat == BeginArray) { advanceState(cstring("a", strlen("a")), beginWhat); } else { advanceState(cstring("a{", strlen("a{")), beginWhat); } } void Arguments::Writer::beginArray(ArrayOption option) { beginArrayOrDict(BeginArray, option); } void Arguments::Writer::endArray() { advanceState(cstring(), EndArray); } void Arguments::Writer::beginDict(ArrayOption option) { beginArrayOrDict(BeginDict, option); } void Arguments::Writer::endDict() { advanceState(cstring("}", strlen("}")), EndDict); } #ifdef WITH_DICT_ENTRY void Arguments::Writer::beginDictEntry() { VALID_IF(m_state == BeginDictEntry, Error::MisplacedBeginDictEntry); advanceState(cstring(), BeginDictEntry); } void Arguments::Writer::endDictEntry() { if (!d->m_aggregateStack.empty()) { Private::AggregateInfo &aggregateInfo = d->m_aggregateStack.back(); if (aggregateInfo.aggregateType == BeginDict && aggregateInfo.arr.dictEntryState == Private::RequireEndDictEntry) { advanceState(cstring(), EndDictEntry); return; } } VALID_IF(false, Error::MisplacedEndDictEntry); } #endif void Arguments::Writer::beginStruct() { advanceState(cstring("(", strlen("(")), BeginStruct); } void Arguments::Writer::endStruct() { advanceState(cstring(")", strlen(")")), EndStruct); } void Arguments::Writer::beginVariant() { advanceState(cstring("v", strlen("v")), BeginVariant); } void Arguments::Writer::endVariant() { advanceState(cstring(), EndVariant); } void Arguments::Writer::writeVariantForMessageHeader(char sig) { // Note: the sugnature we're vorking with there is a(yv) // If we know that and can trust the client, this can be very easy and fast... d->m_signature.ptr[3] = 'v'; d->m_signature.length = 4; d->m_signaturePosition = 4; d->reserveData(d->m_dataPosition + 3, &m_state); d->m_data[d->m_dataPosition++] = 1; d->m_data[d->m_dataPosition++] = sig; d->m_data[d->m_dataPosition++] = 0; } void Arguments::Writer::fixupAfterWriteVariantForMessageHeader() { // We just wrote something to the main signature when we shouldn't have. d->m_signature.length = 4; d->m_signaturePosition = 4; } static char letterForPrimitiveIoState(Arguments::IoState ios) { if (ios < Arguments::Boolean || ios > Arguments::Double) { return 'c'; // a known invalid letter that won't trip up typeInfo() } static const char letters[] = { 'b', // Boolean 'y', // Byte 'n', // Int16 'q', // Uint16 'i', // Int32 'u', // Uint32 'x', // Int64 't', // Uint64 'd' // Double }; return letters[size_t(ios) - size_t(Arguments::Boolean)]; // TODO do we need the casts? } void Arguments::Writer::writePrimitiveArray(IoState type, chunk data) { const char letterCode = letterForPrimitiveIoState(type); if (letterCode == 'c') { m_state = InvalidData; d->m_error.setCode(Error::NotPrimitiveType); return; } if (data.length > Arguments::MaxArrayLength) { m_state = InvalidData; d->m_error.setCode(Error::ArrayOrDictTooLong); return; } const TypeInfo elementType = typeInfo(letterCode); if (!isAligned(data.length, elementType.alignment)) { m_state = InvalidData; d->m_error.setCode(Error::CannotEndArrayOrDictHere); return; } beginArray(data.length ? NonEmptyArray : WriteTypesOfEmptyArray); // dummy write to write the signature... m_u.Uint64 = 0; advanceState(cstring(&letterCode, /*length*/ 1), elementType.state()); if (!data.length) { // oh! a nil array (which is valid) endArray(); return; } // undo the dummy write (except for the preceding alignment bytes, if any) d->m_dataPosition -= elementType.alignment; if (d->insideVariant()) { d->m_queuedData.pop_back(); d->m_queuedData.push_back(Private::QueuedDataInfo(elementType.alignment, 0)); } // append the payload d->reserveData(d->m_dataPosition + data.length, &m_state); d->appendBulkData(data); endArray(); } Arguments Arguments::Writer::finish() { // what needs to happen here: // - check if the message can be closed - basically the aggregate stack must be empty // - close the signature by adding the terminating null - // TODO set error in returned Arguments in error cases Arguments args; if (m_state == InvalidData) { - return args; + args.d->m_error = d->m_error; + return args; // heavily relying on NRVO in all returns here! } if (d->m_nesting.total() != 0) { m_state = InvalidData; d->m_error.setCode(Error::CannotEndArgumentsHere); + args.d->m_error = d->m_error; return args; } assert(!d->m_nilArrayNesting); assert(!d->insideVariant()); assert(d->m_signaturePosition <= MaxSignatureLength); // this should have been caught before assert(d->m_signature.ptr == reinterpret_cast(d->m_data) + 1); // Note that we still keep the full SignatureReservedSpace for the main signature, which means // less copying around to shrink the gap between signature and data, but also wastes an enormous // amount of space (relative to the possible minimum) in some cases. It should not be a big space // problem because normally not many D-Bus Message / Arguments instances exist at the same time. d->m_signature.length = d->m_signaturePosition; d->m_signature.ptr[d->m_signature.length] = '\0'; - args.d->m_error = d->m_error; // OK, so this length check is more of a sanity check. The actual limit limits the size of the // full message. Here we take the size of the "payload" and don't add the size of the signature - // why bother doing it accurately when the real check with full information comes later anyway? bool success = true; const uint32 dataSize = d->m_dataPosition - Private::SignatureReservedSpace; if (success && dataSize > Arguments::MaxMessageLength) { success = false; d->m_error.setCode(Error::ArgumentsTooLong); } if (!dataSize || !success) { args.d->m_memOwnership = nullptr; args.d->m_signature = cstring(); args.d->m_data = chunk(); } else { args.d->m_memOwnership = d->m_data; args.d->m_signature = cstring(d->m_data + 1 /* w/o length prefix */, d->m_signature.length); args.d->m_data = chunk(d->m_data + Private::SignatureReservedSpace, dataSize); d->m_data = nullptr; // now owned by Arguments and later freed there } - if (!success) { + if (success) { + args.d->m_fileDescriptors = std::move(d->m_fileDescriptors); + m_state = Finished; + } else { m_state = InvalidData; - return Arguments(); + args.d->m_error = d->m_error; } - args.d->m_fileDescriptors = std::move(d->m_fileDescriptors); - m_state = Finished; - return std::move(args); + return args; } struct ArrayLengthField { uint32 lengthFieldPosition; uint32 dataStartPosition; }; void Arguments::Writer::flushQueuedData() { const uint32 count = d->m_queuedData.size(); assert(count); // just don't call this method otherwise! // Note: if one of signature or data is nonempty, the other must also be nonempty. // Even "empty" things like empty arrays or null strings have a size field, in that case // (for all(?) types) of value zero. // Copy the signature and main data (thus the whole contents) into one allocated block, // which is good to have for performance and simplicity reasons. // The maximum alignment blowup for naturally aligned types is just less than a factor of 2. // Structs and dict entries are always 8 byte aligned so they add a maximum blowup of 7 bytes // each (when they contain a byte). // Those estimates are very conservative (but easy!), so some space optimization is possible. uint32 inPos = d->m_dataPositionBeforeVariant; uint32 outPos = d->m_dataPositionBeforeVariant; byte *const buffer = d->m_data; std::vector lengthFieldStack; for (uint32 i = 0; i < count; i++) { const Private::QueuedDataInfo ei = d->m_queuedData[i]; switch (ei.size) { case 0: { inPos = align(inPos, ei.alignment()); zeroPad(buffer, ei.alignment(), &outPos); } break; default: { assert(ei.size && ei.size <= Private::QueuedDataInfo::LargestSize); inPos = align(inPos, ei.alignment()); zeroPad(buffer, ei.alignment(), &outPos); // copy data chunk memmove(buffer + outPos, buffer + inPos, ei.size); inPos += ei.size; outPos += ei.size; } break; case Private::QueuedDataInfo::ArrayLengthField: { // start of an array // alignment padding before length field inPos = align(inPos, ei.alignment()); zeroPad(buffer, ei.alignment(), &outPos); // reserve length field ArrayLengthField al; al.lengthFieldPosition = outPos; inPos += sizeof(uint32); outPos += sizeof(uint32); // alignment padding before first array element assert(i + 1 < d->m_queuedData.size()); const uint32 contentsAlignment = d->m_queuedData[i + 1].alignment(); inPos = align(inPos, contentsAlignment); zeroPad(buffer, contentsAlignment, &outPos); // array data starts at the first array element position after alignment al.dataStartPosition = outPos; lengthFieldStack.push_back(al); } break; case Private::QueuedDataInfo::ArrayLengthEndMark: { // end of an array // just put the now known array length in front of the array const ArrayLengthField al = lengthFieldStack.back(); const uint32 arrayLength = outPos - al.dataStartPosition; if (arrayLength > Arguments::MaxArrayLength) { m_state = InvalidData; d->m_error.setCode(Error::ArrayOrDictTooLong); i = count + 1; // break out of the loop break; } basic::writeUint32(buffer + al.lengthFieldPosition, arrayLength); lengthFieldStack.pop_back(); } break; case Private::QueuedDataInfo::VariantSignature: { // move the signature and add its null terminator const uint32 length = buffer[inPos] + 1; // + length prefix memmove(buffer + outPos, buffer + inPos, length); buffer[outPos + length] = '\0'; outPos += length + 1; // + null terminator inPos += Private::Private::SignatureReservedSpace; } break; } } assert(m_state == InvalidData || lengthFieldStack.empty()); d->m_dataPosition = outPos; d->m_queuedData.clear(); } std::vector Arguments::Writer::aggregateStack() const { std::vector ret; ret.reserve(d->m_aggregateStack.size()); for (Private::AggregateInfo &aggregate : d->m_aggregateStack) { ret.push_back(aggregate.aggregateType); } return ret; } uint32 Arguments::Writer::aggregateDepth() const { return d->m_aggregateStack.size(); } Arguments::IoState Arguments::Writer::currentAggregate() const { if (d->m_aggregateStack.empty()) { return NotStarted; } return d->m_aggregateStack.back().aggregateType; } chunk Arguments::Writer::peekSerializedData() const { chunk ret; if (isValid() && m_state != InvalidData && d->m_nesting.total() == 0) { ret.ptr = d->m_data + Private::SignatureReservedSpace; ret.length = d->m_dataPosition - Private::SignatureReservedSpace; } return ret; } const std::vector &Arguments::Writer::fileDescriptors() const { return d->m_fileDescriptors; } void Arguments::Writer::writeBoolean(bool b) { m_u.Boolean = b; advanceState(cstring("b", strlen("b")), Boolean); } void Arguments::Writer::writeByte(byte b) { m_u.Byte = b; advanceState(cstring("y", strlen("y")), Byte); } void Arguments::Writer::writeInt16(int16 i) { m_u.Int16 = i; advanceState(cstring("n", strlen("n")), Int16); } void Arguments::Writer::writeUint16(uint16 i) { m_u.Uint16 = i; advanceState(cstring("q", strlen("q")), Uint16); } void Arguments::Writer::writeInt32(int32 i) { m_u.Int32 = i; advanceState(cstring("i", strlen("i")), Int32); } void Arguments::Writer::writeUint32(uint32 i) { m_u.Uint32 = i; advanceState(cstring("u", strlen("u")), Uint32); } void Arguments::Writer::writeInt64(int64 i) { m_u.Int64 = i; advanceState(cstring("x", strlen("x")), Int64); } void Arguments::Writer::writeUint64(uint64 i) { m_u.Uint64 = i; advanceState(cstring("t", strlen("t")), Uint64); } void Arguments::Writer::writeDouble(double d) { m_u.Double = d; advanceState(cstring("d", strlen("d")), Double); } void Arguments::Writer::writeString(cstring string) { m_u.String.ptr = string.ptr; m_u.String.length = string.length; advanceState(cstring("s", strlen("s")), String); } void Arguments::Writer::writeObjectPath(cstring objectPath) { m_u.String.ptr = objectPath.ptr; m_u.String.length = objectPath.length; advanceState(cstring("o", strlen("o")), ObjectPath); } void Arguments::Writer::writeSignature(cstring signature) { m_u.String.ptr = signature.ptr; m_u.String.length = signature.length; advanceState(cstring("g", strlen("g")), Signature); } void Arguments::Writer::writeUnixFd(int32 fd) { m_u.Int32 = fd; advanceState(cstring("h", strlen("h")), UnixFd); } diff --git a/serialization/message.cpp b/serialization/message.cpp index b1db15a..d7988dc 100644 --- a/serialization/message.cpp +++ b/serialization/message.cpp @@ -1,1312 +1,1370 @@ /* Copyright (C) 2013 Andreas Hartmetz This library is free software; you can redistribute it and/or modify it under the terms of the GNU Library General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This library is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Library General Public License for more details. You should have received a copy of the GNU Library General Public License along with this library; see the file COPYING.LGPL. If not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. Alternatively, this file is available under the Mozilla Public License Version 1.1. You may obtain a copy of the License at http://www.mozilla.org/MPL/ */ #include "message.h" #include "message_p.h" +#include "arguments_p.h" #include "basictypeio.h" #include "malloccache.h" #include "stringtools.h" #ifndef DFERRY_SERDES_ONLY #include "icompletionlistener.h" #include "itransport.h" #endif #include #include #include #include #include #ifdef __unix__ #include #endif #ifdef BIGENDIAN static const byte s_thisMachineEndianness = 'b'; #else static const byte s_thisMachineEndianness = 'l'; #endif struct MsgAllocCaches { MallocCache msgPrivate; MallocCache<256, 4> msgBuffer; }; thread_local static MsgAllocCaches msgAllocCaches; static const byte s_storageForHeader[Message::UnixFdsHeader + 1] = { 0, // dummy entry: there is no enum value for 0 0x80 | 0, // PathHeader 0x80 | 1, // InterfaceHeader 0x80 | 2, // MethodHeader 0x80 | 3, // ErrorNameHeader 0 | 0, // ReplySerialHeader 0x80 | 4, // DestinationHeader 0x80 | 5, // SenderHeader 0x80 | 6, // SignatureHeader 0 | 1 // UnixFdsHeader }; static bool isStringHeader(int field) { return s_storageForHeader[field] & 0x80; } static int indexOfHeader(int field) { return s_storageForHeader[field] & 0x7f; } static const Message::VariableHeader s_stringHeaderAtIndex[VarHeaderStorage::s_stringHeaderCount] = { Message::PathHeader, Message::InterfaceHeader, Message::MethodHeader, Message::ErrorNameHeader, Message::DestinationHeader, Message::SenderHeader, Message::SignatureHeader }; static const Message::VariableHeader s_intHeaderAtIndex[VarHeaderStorage::s_intHeaderCount] = { Message::ReplySerialHeader, Message::UnixFdsHeader }; VarHeaderStorage::VarHeaderStorage() {} // initialization values are in class declaration VarHeaderStorage::VarHeaderStorage(const VarHeaderStorage &other) { // ### very suboptimal for (int i = 0; i < Message::UnixFdsHeader + 1; i++) { Message::VariableHeader vh = static_cast(i); if (other.hasHeader(vh)) { if (isStringHeader(vh)) { setStringHeader(vh, other.stringHeader(vh)); } else { setIntHeader(vh, other.intHeader(vh)); } } } } VarHeaderStorage::~VarHeaderStorage() { for (int i = 0; i < s_stringHeaderCount; i++) { const Message::VariableHeader field = s_stringHeaderAtIndex[i]; if (hasHeader(field)) { // ~basic_string() instead of ~string() to work around a GCC bug stringHeaders()[i].~basic_string(); } } } bool VarHeaderStorage::hasHeader(Message::VariableHeader header) const { return m_headerPresenceBitmap & (1u << header); } bool VarHeaderStorage::hasStringHeader(Message::VariableHeader header) const { return hasHeader(header) && isStringHeader(header); } bool VarHeaderStorage::hasIntHeader(Message::VariableHeader header) const { return hasHeader(header) && !isStringHeader(header); } std::string VarHeaderStorage::stringHeader(Message::VariableHeader header) const { return hasStringHeader(header) ? stringHeaders()[indexOfHeader(header)] : std::string(); } cstring VarHeaderStorage::stringHeaderRaw(Message::VariableHeader header) { // this one is supposed to be a const method in the intended use, but it is dangerous so // outwardly non-const is kind of okay as a warning cstring ret; assert(isStringHeader(header)); if (hasHeader(header)) { std::string &str = stringHeaders()[indexOfHeader(header)]; ret.ptr = const_cast(str.c_str()); ret.length = str.length(); } return ret; } void VarHeaderStorage::setStringHeader(Message::VariableHeader header, const std::string &value) { if (!isStringHeader(header)) { return; } const int idx = indexOfHeader(header); if (hasHeader(header)) { stringHeaders()[idx] = value; } else { m_headerPresenceBitmap |= 1u << header; new(stringHeaders() + idx) std::string(value); } } bool VarHeaderStorage::setStringHeader_deser(Message::VariableHeader header, cstring value) { assert(isStringHeader(header)); if (hasHeader(header)) { return false; } m_headerPresenceBitmap |= 1u << header; new(stringHeaders() + indexOfHeader(header)) std::string(value.ptr, value.length); return true; } void VarHeaderStorage::clearStringHeader(Message::VariableHeader header) { if (!isStringHeader(header)) { return; } if (hasHeader(header)) { m_headerPresenceBitmap &= ~(1u << header); stringHeaders()[indexOfHeader(header)].~basic_string(); } } uint32 VarHeaderStorage::intHeader(Message::VariableHeader header) const { return hasIntHeader(header) ? m_intHeaders[indexOfHeader(header)] : 0; } void VarHeaderStorage::setIntHeader(Message::VariableHeader header, uint32 value) { if (isStringHeader(header)) { return; } m_headerPresenceBitmap |= 1u << header; m_intHeaders[indexOfHeader(header)] = value; } bool VarHeaderStorage::setIntHeader_deser(Message::VariableHeader header, uint32 value) { assert(!isStringHeader(header)); if (hasHeader(header)) { return false; } m_headerPresenceBitmap |= 1u << header; m_intHeaders[indexOfHeader(header)] = value; return true; } void VarHeaderStorage::clearIntHeader(Message::VariableHeader header) { if (isStringHeader(header)) { return; } m_headerPresenceBitmap &= ~(1u << header); } // TODO think of copying signature from and to output! MessagePrivate::MessagePrivate(Message *parent) : m_message(parent), m_bufferPos(0), m_isByteSwapped(false), m_state(Empty), m_messageType(Message::InvalidMessage), m_flags(0), m_protocolVersion(1), m_dirty(true), m_headerLength(0), m_headerPadding(0), m_bodyLength(0), m_serial(0) {} MessagePrivate::MessagePrivate(const MessagePrivate &other, Message *parent) : m_message(parent), m_bufferPos(other.m_bufferPos), m_isByteSwapped(other.m_isByteSwapped), m_state(other.m_state), m_messageType(other.m_messageType), m_flags(other.m_flags), m_protocolVersion(other.m_protocolVersion), m_dirty(other.m_dirty), m_headerLength(other.m_headerLength), m_headerPadding(other.m_headerPadding), m_bodyLength(other.m_bodyLength), m_serial(other.m_serial), m_error(other.m_error), m_mainArguments(other.m_mainArguments), m_varHeaders(other.m_varHeaders) { if (other.m_buffer.ptr) { // we don't keep pointers into the buffer (only indexes), right? right? m_buffer.ptr = static_cast(malloc(other.m_buffer.length)); m_buffer.length = other.m_buffer.length; // Simplification: don't try to figure out which part of other.m_buffer contains "valid" data, // just copy everything. memcpy(m_buffer.ptr, other.m_buffer.ptr, other.m_buffer.length); #ifdef __unix__ // TODO ensure all "actual" file descriptor handling everywhere is inside this ifdef // (note conditional compilation of whole file localsocket.cpp) - m_fileDescriptors.clear(); - for (int fd : other.m_fileDescriptors) { + argUnixFds()->clear(); + std::vector *otherUnixFds = const_cast(other).argUnixFds(); + argUnixFds()->reserve(otherUnixFds->size()); + for (int fd : *otherUnixFds) { int fdCopy = ::dup(fd); if (fdCopy == -1) { // TODO error... } - m_fileDescriptors.push_back(fdCopy); + argUnixFds()->push_back(fdCopy); } #endif } else { assert(!m_buffer.length); } // ### Maybe warn when copying a Message which is currently (de)serializing. It might even be impossible // to do that from client code. If that is the case, the "warning" could even be an assertion because // we should never do such a thing. } MessagePrivate::~MessagePrivate() { - clearBuffer(); + clear(/* onlyReleaseResources = */ true); } Message::Message() : d(new(msgAllocCaches.msgPrivate.allocate()) MessagePrivate(this)) { } Message::Message(Message &&other) : d(other.d) { other.d = nullptr; d->m_message = this; } Message &Message::operator=(Message &&other) { if (this != &other) { if (d) { d->~MessagePrivate(); msgAllocCaches.msgPrivate.free(d); } d = other.d; if (other.d) { other.d = nullptr; d->m_message = this; } } return *this; } Message::Message(const Message &other) : d(nullptr) { if (!other.d) { return; } d = new(msgAllocCaches.msgPrivate.allocate()) MessagePrivate(*other.d, this); } Message &Message::operator=(const Message &other) { if (this != &other) { if (d) { d->~MessagePrivate(); msgAllocCaches.msgPrivate.free(d); } if (other.d) { // ### can be optimized by implementing and using assignment of MessagePrivate d = new(msgAllocCaches.msgPrivate.allocate()) MessagePrivate(*other.d, this); } else { d = nullptr; } } return *this; } Message::~Message() { if (d) { d->~MessagePrivate(); msgAllocCaches.msgPrivate.free(d); d = nullptr; } } Error Message::error() const { return d->m_error; } void Message::setCall(const std::string &path, const std::string &interface, const std::string &method) { setType(MethodCallMessage); setPath(path); setInterface(interface); setMethod(method); } void Message::setCall(const std::string &path, const std::string &method) { setType(MethodCallMessage); setPath(path); setMethod(method); } void Message::setReplyTo(const Message &call) { setType(MethodReturnMessage); setDestination(call.sender()); setReplySerial(call.serial()); } void Message::setErrorReplyTo(const Message &call, const std::string &errorName) { setType(ErrorMessage); setErrorName(errorName); setDestination(call.sender()); setReplySerial(call.serial()); } void Message::setSignal(const std::string &path, const std::string &interface, const std::string &method) { setType(SignalMessage); setPath(path); setInterface(interface); setMethod(method); } Message Message::createCall(const std::string &path, const std::string &interface, const std::string &method) { Message ret; ret.setCall(path, interface, method); return ret; } Message Message::createCall(const std::string &path, const std::string &method) { Message ret; ret.setCall(path, method); return ret; } Message Message::createReplyTo(const Message &call) { Message ret; ret.setReplyTo(call); return ret; } Message Message::createErrorReplyTo(const Message &call, const std::string &errorName) { Message ret; ret.setErrorReplyTo(call, errorName); return ret; } Message Message::createSignal(const std::string &path, const std::string &interface, const std::string &method) { Message ret; ret.setSignal(path, interface, method); return ret; } struct VarHeaderPrinter { Message::VariableHeader field; const char *name; }; static const int stringHeadersCount = 7; static VarHeaderPrinter stringHeaderPrinters[stringHeadersCount] = { { Message::PathHeader, "path" }, { Message::InterfaceHeader, "interface" }, { Message::MethodHeader, "method" }, { Message::ErrorNameHeader, "error name" }, { Message::DestinationHeader, "destination" }, { Message::SenderHeader, "sender" }, { Message::SignatureHeader, "signature" } }; static const int intHeadersCount = 2; static VarHeaderPrinter intHeaderPrinters[intHeadersCount] = { { Message::ReplySerialHeader, "reply serial" }, { Message::UnixFdsHeader, "#unix fds" } }; static const int messageTypeCount = 5; static const char *printableMessageTypes[messageTypeCount] = { "", // handled in code "Method call", "Method return", "Method error return", "Signal" }; std::string Message::prettyPrint() const { std::string ret; if (d->m_messageType >= 1 && d->m_messageType < messageTypeCount) { ret += printableMessageTypes[d->m_messageType]; } else { return std::string("Invalid message.\n"); } std::ostringstream os; for (int i = 0; i < stringHeadersCount; i++ ) { bool isPresent = false; std::string str = stringHeader(stringHeaderPrinters[i].field, &isPresent); if (isPresent) { os << "; " << stringHeaderPrinters[i].name << ": \"" << str << '"'; } } for (int i = 0; i < intHeadersCount; i++ ) { bool isPresent = false; uint32 intValue = intHeader(intHeaderPrinters[i].field, &isPresent); if (isPresent) { os << "; " << intHeaderPrinters[i].name << ": " << intValue; } } ret += os.str(); ret += '\n'; ret += d->m_mainArguments.prettyPrint(); return ret; } Message::Type Message::type() const { return d->m_messageType; } void Message::setType(Type type) { if (d->m_messageType == type) { return; } d->m_dirty = true; d->m_messageType = type; setExpectsReply(d->m_messageType == MethodCallMessage); } uint32 Message::protocolVersion() const { return d->m_protocolVersion; } void Message::setSerial(uint32 serial) { d->m_serial = serial; if (d->m_state == MessagePrivate::Serialized && !d->m_dirty) { // performance hack: setSerial is likely to happen just before sending - don't re-serialize, // just patch it. byte *p = d->m_buffer.ptr; basic::writeUint32(p + 4 /* bytes */ + sizeof(uint32), d->m_serial); return; } d->m_dirty = true; } uint32 Message::serial() const { return d->m_serial; } std::string Message::path() const { return stringHeader(PathHeader, 0); } void Message::setPath(const std::string &path) { setStringHeader(PathHeader, path); } std::string Message::interface() const { return stringHeader(InterfaceHeader, 0); } void Message::setInterface(const std::string &interface) { setStringHeader(InterfaceHeader, interface); } std::string Message::method() const { return stringHeader(MethodHeader, 0); } void Message::setMethod(const std::string &method) { setStringHeader(MethodHeader, method); } std::string Message::errorName() const { return stringHeader(ErrorNameHeader, 0); } void Message::setErrorName(const std::string &errorName) { setStringHeader(ErrorNameHeader, errorName); } uint32 Message::replySerial() const { return intHeader(ReplySerialHeader, 0); } void Message::setReplySerial(uint32 replySerial) { setIntHeader(ReplySerialHeader, replySerial); } std::string Message::destination() const { return stringHeader(DestinationHeader, 0); } void Message::setDestination(const std::string &destination) { setStringHeader(DestinationHeader, destination); } std::string Message::sender() const { return stringHeader(SenderHeader, 0); } void Message::setSender(const std::string &sender) { setStringHeader(SenderHeader, sender); } std::string Message::signature() const { return stringHeader(SignatureHeader, 0); } uint32 Message::unixFdCount() const { return intHeader(UnixFdsHeader, 0); } std::string Message::stringHeader(VariableHeader header, bool *isPresent) const { const bool exists = d->m_varHeaders.hasStringHeader(header); if (isPresent) { *isPresent = exists; } return exists ? d->m_varHeaders.stringHeader(header) : std::string(); } void Message::setStringHeader(VariableHeader header, const std::string &value) { if (header == SignatureHeader) { // ### warning? - this is a public method, and setting the signature separately does not make sense return; } d->m_dirty = true; d->m_varHeaders.setStringHeader(header, value); } uint32 Message::intHeader(VariableHeader header, bool *isPresent) const { const bool exists = d->m_varHeaders.hasIntHeader(header); if (isPresent) { *isPresent = exists; } return d->m_varHeaders.intHeader(header); } void Message::setIntHeader(VariableHeader header, uint32 value) { d->m_dirty = true; d->m_varHeaders.setIntHeader(header, value); } bool Message::expectsReply() const { return (d->m_flags & MessagePrivate::NoReplyExpectedFlag) == 0; } void Message::setExpectsReply(bool expectsReply) { if (expectsReply) { d->m_flags &= ~MessagePrivate::NoReplyExpectedFlag; } else { d->m_flags |= MessagePrivate::NoReplyExpectedFlag; } } bool Message::autoStartService() const { return (d->m_flags & MessagePrivate::NoAutoStartServiceFlag) == 0; } void Message::setAutoStartService(bool autoStart) const { if (autoStart) { d->m_flags &= ~MessagePrivate::NoAutoStartServiceFlag; } else { d->m_flags |= MessagePrivate::NoAutoStartServiceFlag; } } bool Message::interactiveAuthorizationAllowed() const { return (d->m_flags & MessagePrivate::NoAllowInteractiveAuthorizationFlag) == 0; } void Message::setInteractiveAuthorizationAllowed(bool allowInteractive) const { if (allowInteractive) { d->m_flags &= ~MessagePrivate::NoAllowInteractiveAuthorizationFlag; } else { d->m_flags |= MessagePrivate::NoAllowInteractiveAuthorizationFlag; } } void Message::setArguments(Arguments arguments) { d->m_dirty = true; d->m_error = arguments.error(); const size_t fdCount = arguments.fileDescriptors().size(); if (fdCount) { d->m_varHeaders.setIntHeader(Message::UnixFdsHeader, fdCount); } else { d->m_varHeaders.clearIntHeader(Message::UnixFdsHeader); } cstring signature = arguments.signature(); if (signature.length) { d->m_varHeaders.setStringHeader(Message::SignatureHeader, toStdString(signature)); } else { d->m_varHeaders.clearStringHeader(Message::SignatureHeader); } d->m_mainArguments = std::move(arguments); } const Arguments &Message::arguments() const { return d->m_mainArguments; } static const uint32 s_properFixedHeaderLength = 12; static const uint32 s_extendedFixedHeaderLength = 16; #ifndef DFERRY_SERDES_ONLY void MessagePrivate::receive(ITransport *transport) { if (m_state >= FirstIoState) { // Can only do one I/O operation at a time return; } transport->setReadListener(this); m_state = MessagePrivate::Receiving; m_headerLength = 0; m_bodyLength = 0; } bool Message::isReceiving() const { return d->m_state == MessagePrivate::Receiving; } void MessagePrivate::send(ITransport *transport) { if (!serialize()) { - // TODO - // m_error.setCode(); - // notifyCompletionListener(); would call into Connection, but it's easier for Connection to handle - // the error from non-callback code, directly in the caller of send(). + m_state = Serialized; + assert(m_error.isError()); + if (!m_error.isError()) { + // TODO This error code makes no sense here. We don't have a "generic error" code, and a specific + // error code should have been set by whatever code detected the error. Hence the assertion. + m_error = Error::MalformedReply; + } + // Note: We don't call notifyCompletionListener(); since all our timers are internal, we can + // expect them to check for errors after this returns. Specifically, Connection::send() has + // access to the timer in the PendingReply, so it can produce a deferred error notification + // just like success notifications are deferred. return; } if (m_state != MessagePrivate::Sending) { transport->setWriteListener(this); m_state = MessagePrivate::Sending; } } bool Message::isSending() const { return d->m_state == MessagePrivate::Sending; } void MessagePrivate::setCompletionListener(ICompletionListener *listener) { m_completionListener = listener; } void MessagePrivate::notifyCompletionListener() { if (m_completionListener) { m_completionListener->handleCompletion(m_message); } } IO::Status MessagePrivate::handleTransportCanRead() { if (m_state != Receiving) { return IO::Status::InternalError; } IO::Status ret = IO::Status::OK; IO::Result ioRes; do { uint32 readMax = 0; if (!m_headerLength) { // the message might only consist of the header, so we must be careful to avoid reading // data meant for the next message readMax = s_extendedFixedHeaderLength - m_bufferPos; } else { // reading variable headers and/or body readMax = m_headerLength + m_bodyLength - m_bufferPos; } reserveBuffer(m_bufferPos + readMax); const bool headersDone = m_headerLength > 0 && m_bufferPos >= m_headerLength; if (m_bufferPos == 0) { - // File descriptors should arrive only with the first byte + // According to the DBus spec, file descriptors can arrive anywhere in the message, but + // (assuming the message is written in one sendmsg() call) according to UNIX domain socket + // documentation, file descriptors will arrive in the first byte of the sent block they are + // attached to. We go with the UNIX domain socket documentation. + // TODO review and test this for very large messages that cannot be sent in one call ioRes = readTransport()->readWithFileDescriptors(m_buffer.ptr + m_bufferPos, readMax, - &m_fileDescriptors); + argUnixFds()); } else { ioRes = readTransport()->read(m_buffer.ptr + m_bufferPos, readMax); } m_bufferPos += ioRes.length; assert(m_bufferPos <= m_buffer.length); if (!headersDone) { if (m_headerLength == 0 && m_bufferPos >= s_extendedFixedHeaderLength) { if (!deserializeFixedHeaders()) { - ret = IO::Status::InternalError; // TODO ... m_error = Error::MalformetReply? + ret = IO::Status::RemoteClosed; + m_error = Error::MalformedReply; break; } } if (m_headerLength > 0 && m_bufferPos >= m_headerLength) { - if (deserializeVariableHeaders()) { - const uint32 fdsCount = m_varHeaders.intHeader(Message::UnixFdsHeader); - if (fdsCount != m_fileDescriptors.size()) { - ret = IO::Status::InternalError; // TODO - break; - } - } else { - ret = IO::Status::InternalError; // TODO + // ### If we expected to receive the FDs at any point in the message (as opposed to just the + // first byte), we'd have to verify FD count later. But we don't, so this is expedient. + if (!deserializeVariableHeaders() || + m_varHeaders.intHeader(Message::UnixFdsHeader) != argUnixFds()->size()) { + ret = IO::Status::RemoteClosed; + m_error = Error::MalformedReply; break; } } } if (m_headerLength > 0 && m_bufferPos >= m_headerLength + m_bodyLength) { // all done! assert(m_bufferPos == m_headerLength + m_bodyLength); m_state = Serialized; chunk bodyData(m_buffer.ptr + m_headerLength, m_bodyLength); m_mainArguments = Arguments(nullptr, m_varHeaders.stringHeaderRaw(Message::SignatureHeader), - bodyData, std::move(m_fileDescriptors), m_isByteSwapped); - m_fileDescriptors.clear(); // put it into a well-defined state + bodyData, std::move(*argUnixFds()), m_isByteSwapped); assert(ioRes.status == IO::Status::OK && ret == IO::Status::OK); readTransport()->setReadListener(nullptr); notifyCompletionListener(); // do not access members after this because it might delete us! break; } if (!readTransport()->isOpen()) { - ret = IO::Status::InternalError; // TODO + ret = IO::Status::RemoteClosed; break; } } while (ioRes.status == IO::Status::OK); if (ret != IO::Status::OK) { - m_state = Empty; - clearBuffer(); + clear(); readTransport()->setReadListener(nullptr); - m_error = Error::RemoteDisconnect; + if (!m_error.isError()) { + // catch-all, we know that SOME error happened + m_error = Error::RemoteDisconnect; + } notifyCompletionListener(); - // TODO reset other data members, SET ERROR, generally revisit error handling to make it robust } return ret; } IO::Status MessagePrivate::handleTransportCanWrite() { if (m_state != Sending) { return IO::Status::InternalError; } while (true) { assert(m_buffer.length >= m_bufferPos); const uint32 toWrite = m_buffer.length - m_bufferPos; if (!toWrite) { m_state = Serialized; writeTransport()->setWriteListener(nullptr); notifyCompletionListener(); break; } IO::Result ioRes; if (m_bufferPos == 0) { - // Send file descriptors and / or credentials with first byte of message. We could call - // write() after checking that we don't need to send fds (easy) or credentials (which would - // be a slight layering violation). - ioRes = writeTransport()->writeWithFileDescriptors(chunk(m_buffer.ptr + m_bufferPos, toWrite), - m_mainArguments.fileDescriptors()); + const size_t sendFdsCount = m_mainArguments.fileDescriptors().size(); + if (sendFdsCount == 0) { + ioRes = writeTransport()->write(chunk(m_buffer.ptr + m_bufferPos, toWrite)); + } else if (sendFdsCount > writeTransport()->supportedPassingUnixFdsCount()) { + m_error.setCode(Error::SendingTooManyUnixFds); + m_state = Serialized; + writeTransport()->setWriteListener(nullptr); + // ### Oh well, now we have a special Error value to pass through the stack + // (for error handling), but also notifyCompletionListener() for sucessful completion + // handling. Can we get rid of one or the other, or are there actually good reasons for + // the difference? + // Pro separation: + // - the arguments I came up with for not doing error handling through callbacks + // - the ...convenience?... of using callbacks + // - possibly better performance of happy path + // + // Contra separation: + // - two different mechanisms for similar things... + // - inconsistency - though is that really a problem? the situations are different. + // - suddenly IO code needs to ~know (at least pass through and be technically exposed to) + // error values it doesn't know and can't handle itself; theoretically could use some + // error value wrapping mechanism to pass through opaque errors). + return IO::Status::PayloadError; // the connection is fine, only this message has a problem + } else { + ioRes = writeTransport() + ->writeWithFileDescriptors(chunk(m_buffer.ptr + m_bufferPos, toWrite), + m_mainArguments.fileDescriptors()); + } } else { ioRes = writeTransport()->write(chunk(m_buffer.ptr + m_bufferPos, toWrite)); } if (ioRes.status != IO::Status::OK) { + // ### what about m_error? + // I think we only test remote disconnect while reading. We should also check while writing. + // - Well, actually, Connection aborts all pending replies with error when we return an error + // her. But due to th limited amount of info in IO::Status, it can only report errors + // corresponding directly to IO::Status values. + m_state = Serialized; // in a way... serialization has completed, unsuccessfully writeTransport()->setWriteListener(nullptr); - // TODO notifyCompletionListener() for failure? - // TODO state update? + notifyCompletionListener(); return IO::Status::RemoteClosed; - break; } m_bufferPos += ioRes.length; } return IO::Status::OK; } #endif // !DFERRY_SERDES_ONLY chunk Message::serializeAndView() { chunk ret; // one return variable to enable return value optimization (RVO) in gcc if (!d->serialize()) { - // TODO report error? return ret; } ret = d->m_buffer; return ret; } std::vector Message::save() { std::vector ret; if (!d->serialize()) { return ret; } ret.reserve(d->m_buffer.length); for (uint32 i = 0; i < d->m_buffer.length; i++) { ret.push_back(d->m_buffer.ptr[i]); } return ret; } void Message::deserializeAndTake(chunk memOwnership) { if (d->m_state >= MessagePrivate::FirstIoState) { free(memOwnership.ptr); return; } d->m_headerLength = 0; d->m_bodyLength = 0; d->clearBuffer(); d->m_buffer = memOwnership; d->m_bufferPos = d->m_buffer.length; bool ok = d->m_buffer.length >= s_extendedFixedHeaderLength; ok = ok && d->deserializeFixedHeaders(); ok = ok && d->m_buffer.length >= d->m_headerLength; ok = ok && d->deserializeVariableHeaders(); ok = ok && d->m_buffer.length == d->m_headerLength + d->m_bodyLength; if (!ok) { - d->m_state = MessagePrivate::Empty; - d->clearBuffer(); + if (!d->m_error.isError()) { + d->m_error = Error::MalformedReply; + } + d->clear(); return; } chunk bodyData(d->m_buffer.ptr + d->m_headerLength, d->m_bodyLength); d->m_mainArguments = Arguments(nullptr, d->m_varHeaders.stringHeaderRaw(SignatureHeader), bodyData, d->m_isByteSwapped); d->m_state = MessagePrivate::Serialized; } // This does not return bool because full validation of the main arguments would take quite // a few cycles. Validating only the header of the message doesn't seem to be worth it. void Message::load(const std::vector &data) { if (d->m_state >= MessagePrivate::FirstIoState || data.empty()) { return; } chunk buf; buf.length = data.size(); buf.ptr = reinterpret_cast(malloc(buf.length)); deserializeAndTake(buf); } bool MessagePrivate::requiredHeadersPresent() { m_error = checkRequiredHeaders(); return !m_error.isError(); } Error MessagePrivate::checkRequiredHeaders() const { if (m_serial == 0) { return Error::MessageSerial; } if (m_protocolVersion != 1) { return Error::MessageProtocolVersion; } // might want to check for DestinationHeader if the transport is a bus (not peer-to-peer) // very strange that this isn't in the spec! switch (m_messageType) { case Message::SignalMessage: // required: PathHeader, InterfaceHeader, MethodHeader if (!m_varHeaders.hasStringHeader(Message::InterfaceHeader)) { return Error::MessageInterface; } // fall through case Message::MethodCallMessage: // required: PathHeader, MethodHeader if (!m_varHeaders.hasStringHeader(Message::PathHeader)) { return Error::MessagePath; } if (!m_varHeaders.hasStringHeader(Message::MethodHeader)) { return Error::MessageMethod; } break; case Message::ErrorMessage: // required: ErrorNameHeader, ReplySerialHeader if (!m_varHeaders.hasStringHeader(Message::ErrorNameHeader)) { return Error::MessageErrorName; } // fall through case Message::MethodReturnMessage: // required: ReplySerialHeader if (!m_varHeaders.hasIntHeader(Message::ReplySerialHeader) ) { return Error::MessageReplySerial; } break; case Message::InvalidMessage: default: return Error::MessageType; } return Error::NoError; } bool MessagePrivate::deserializeFixedHeaders() { assert(m_bufferPos >= s_extendedFixedHeaderLength); byte *p = m_buffer.ptr; byte endianness = *p++; if (endianness != 'l' && endianness != 'B') { return false; } m_isByteSwapped = endianness != s_thisMachineEndianness; // TODO validate the values read here m_messageType = static_cast(*p++); m_flags = *p++; m_protocolVersion = *p++; m_bodyLength = basic::readUint32(p, m_isByteSwapped); m_serial = basic::readUint32(p + sizeof(uint32), m_isByteSwapped); // peek into the var-length header and use knowledge about array serialization to infer the // number of bytes still required for the header uint32 varArrayLength = basic::readUint32(p + 2 * sizeof(uint32), m_isByteSwapped); uint32 unpaddedHeaderLength = s_extendedFixedHeaderLength + varArrayLength; m_headerLength = align(unpaddedHeaderLength, 8); m_headerPadding = m_headerLength - unpaddedHeaderLength; return m_headerLength + m_bodyLength <= Arguments::MaxMessageLength; } bool MessagePrivate::deserializeVariableHeaders() { // use Arguments to parse the variable header fields // HACK: the fake first int argument is there to start the Arguments's data 8 byte aligned byte *base = m_buffer.ptr + s_properFixedHeaderLength - sizeof(int32); chunk headerData(base, m_headerLength - m_headerPadding - s_properFixedHeaderLength + sizeof(int32)); cstring varHeadersSig("ia(yv)"); Arguments argList(nullptr, varHeadersSig, headerData, m_isByteSwapped); Arguments::Reader reader(argList); assert(reader.isValid()); if (reader.state() != Arguments::Int32) { return false; } reader.readInt32(); if (reader.state() != Arguments::BeginArray) { return false; } reader.beginArray(); while (reader.state() == Arguments::BeginStruct) { reader.beginStruct(); const byte headerField = reader.readByte(); if (headerField < Message::PathHeader || headerField > Message::UnixFdsHeader) { return false; } const Message::VariableHeader eHeader = static_cast(headerField); reader.beginVariant(); bool ok = true; // short-circuit evaluation ftw if (isStringHeader(headerField)) { if (headerField == Message::PathHeader) { ok = ok && reader.state() == Arguments::ObjectPath; ok = ok && m_varHeaders.setStringHeader_deser(eHeader, reader.readObjectPath()); } else if (headerField == Message::SignatureHeader) { ok = ok && reader.state() == Arguments::Signature; // The spec allows having no signature header, which means "empty signature". However... // We do not drop empty signature headers when deserializing, in order to preserve // the original message contents. This could be useful for debugging and testing. ok = ok && m_varHeaders.setStringHeader_deser(eHeader, reader.readSignature()); } else { ok = ok && reader.state() == Arguments::String; ok = ok && m_varHeaders.setStringHeader_deser(eHeader, reader.readString()); } } else { ok = ok && reader.state() == Arguments::Uint32; ok = ok && m_varHeaders.setIntHeader_deser(eHeader, reader.readUint32()); } if (!ok) { return false; } reader.endVariant(); reader.endStruct(); } reader.endArray(); // check that header->body padding is in fact zero filled base = m_buffer.ptr; for (uint32 i = m_headerLength - m_headerPadding; i < m_headerLength; i++) { if (base[i] != '\0') { return false; } } return reader.isFinished(); } bool MessagePrivate::serialize() { if ((m_state == Serialized || m_state == Sending) && !m_dirty) { return true; } if (m_state >= FirstIoState) { // Marshalled data must not be touched while doing I/O return false; } clearBuffer(); if (m_error.isError() || !requiredHeadersPresent()) { return false; } Arguments headerArgs = serializeVariableHeaders(); + if (headerArgs.data().length <= 0) { + return false; + } // we need to cut out alignment padding bytes 4 to 7 in the variable header data stream because // the original dbus code aligns based on address in the final data stream // (offset s_properFixedHeaderLength == 12), we align based on address in the Arguments's buffer // (offset 0) - note that our modification keeps the stream valid because length is measured from end // of padding assert(headerArgs.data().length > 0); // if this fails the headerLength hack will break down const uint32 unalignedHeaderLength = s_properFixedHeaderLength + headerArgs.data().length - sizeof(uint32); m_headerLength = align(unalignedHeaderLength, 8); m_bodyLength = m_mainArguments.data().length; const uint32 messageLength = m_headerLength + m_bodyLength; if (messageLength > Arguments::MaxMessageLength) { m_error.setCode(Error::ArgumentsTooLong); return false; } reserveBuffer(messageLength); serializeFixedHeaders(); // copy header data: uint32 length... memcpy(m_buffer.ptr + s_properFixedHeaderLength, headerArgs.data().ptr, sizeof(uint32)); // skip four bytes of padding and copy the rest memcpy(m_buffer.ptr + s_properFixedHeaderLength + sizeof(uint32), headerArgs.data().ptr + 2 * sizeof(uint32), headerArgs.data().length - 2 * sizeof(uint32)); // zero padding between variable headers and message body for (uint32 i = unalignedHeaderLength; i < m_headerLength; i++) { m_buffer.ptr[i] = '\0'; } // copy message body (if any - arguments are not mandatory) if (m_mainArguments.data().length) { memcpy(m_buffer.ptr + m_headerLength, m_mainArguments.data().ptr, m_mainArguments.data().length); } m_bufferPos = m_headerLength + m_mainArguments.data().length; assert(m_bufferPos <= m_buffer.length); // for the upcoming message sending, "reuse" m_bufferPos for read position (formerly write position), // and m_buffer.length for end of data to read (formerly buffer capacity) m_buffer.length = m_bufferPos; m_bufferPos = 0; m_dirty = false; m_state = Serialized; return true; } void MessagePrivate::serializeFixedHeaders() { assert(m_buffer.length >= s_extendedFixedHeaderLength); byte *p = m_buffer.ptr; *p++ = s_thisMachineEndianness; *p++ = byte(m_messageType); *p++ = m_flags; *p++ = m_protocolVersion; basic::writeUint32(p, m_bodyLength); basic::writeUint32(p + sizeof(uint32), m_serial); } static void doVarHeaderPrologue(Arguments::Writer *writer, Message::VariableHeader field) { writer->beginStruct(); writer->writeByte(byte(field)); } Arguments MessagePrivate::serializeVariableHeaders() { Arguments::Writer writer; // note that we don't have to deal with empty arrays because all valid message types require // at least one of the variable headers writer.beginArray(); for (int i = 0; i < VarHeaderStorage::s_stringHeaderCount; i++) { const Message::VariableHeader field = s_stringHeaderAtIndex[i]; if (m_varHeaders.hasHeader(field)) { doVarHeaderPrologue(&writer, field); const std::string &str = m_varHeaders.stringHeaders()[i]; if (field == Message::PathHeader) { writer.writeVariantForMessageHeader('o'); writer.writeObjectPath(cstring(str.c_str(), str.length())); } else if (field == Message::SignatureHeader) { writer.writeVariantForMessageHeader('g'); writer.writeSignature(cstring(str.c_str(), str.length())); } else { writer.writeVariantForMessageHeader('s'); writer.writeString(cstring(str.c_str(), str.length())); } writer.fixupAfterWriteVariantForMessageHeader(); writer.endStruct(); if (unlikely(writer.error().isError())) { static const Error::Code stringHeaderErrors[VarHeaderStorage::s_stringHeaderCount] = { Error::MessagePath, Error::MessageInterface, Error::MessageMethod, Error::MessageErrorName, Error::MessageDestination, Error::MessageSender, Error::MessageSignature }; m_error.setCode(stringHeaderErrors[i]); return Arguments(); } } } for (int i = 0; i < VarHeaderStorage::s_intHeaderCount; i++) { const Message::VariableHeader field = s_intHeaderAtIndex[i]; if (m_varHeaders.hasHeader(field)) { doVarHeaderPrologue(&writer, field); writer.writeVariantForMessageHeader('u'); writer.writeUint32(m_varHeaders.m_intHeaders[i]); writer.fixupAfterWriteVariantForMessageHeader(); writer.endStruct(); } } writer.endArray(); return writer.finish(); } void MessagePrivate::clearBuffer() { if (m_buffer.ptr) { free(m_buffer.ptr); m_buffer = chunk(); m_bufferPos = 0; } else { assert(m_buffer.length == 0); assert(m_bufferPos == 0); } +} + +void MessagePrivate::clear(bool onlyReleaseResources) +{ + clearBuffer(); #ifdef __unix__ - for (int fd : m_fileDescriptors) { + for (int fd : *argUnixFds()) { ::close(fd); } #endif - m_fileDescriptors.clear(); + if (!onlyReleaseResources) { // get into a clean state again + m_state = Empty; + m_mainArguments = Arguments(); + m_varHeaders = VarHeaderStorage(); + } } static uint32 nextPowerOf2(uint32 x) { --x; x |= x >> 1; x |= x >> 2; x |= x >> 4; x |= x >> 8; x |= x >> 16; return ++x; } void MessagePrivate::reserveBuffer(uint32 newLen) { const uint32 oldLen = m_buffer.length; if (newLen <= oldLen) { return; } if (newLen <= 256) { assert(oldLen == 0); newLen = 256; m_buffer.ptr = reinterpret_cast(msgAllocCaches.msgBuffer.allocate()); } else { newLen = nextPowerOf2(newLen); if (oldLen == 256) { byte *newAlloc = reinterpret_cast(malloc(newLen)); memcpy(newAlloc, m_buffer.ptr, oldLen); msgAllocCaches.msgBuffer.free(m_buffer.ptr); m_buffer.ptr = newAlloc; } else { m_buffer.ptr = reinterpret_cast(realloc(m_buffer.ptr, newLen)); } } m_buffer.length = newLen; } + +std::vector *MessagePrivate::argUnixFds() +{ + return &Arguments::Private::get(&m_mainArguments)->m_fileDescriptors; +} diff --git a/serialization/message_p.h b/serialization/message_p.h index 2019fc0..41e2797 100644 --- a/serialization/message_p.h +++ b/serialization/message_p.h @@ -1,148 +1,150 @@ /* Copyright (C) 2014 Andreas Hartmetz This library is free software; you can redistribute it and/or modify it under the terms of the GNU Library General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This library is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Library General Public License for more details. You should have received a copy of the GNU Library General Public License along with this library; see the file COPYING.LGPL. If not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. Alternatively, this file is available under the Mozilla Public License Version 1.1. You may obtain a copy of the License at http://www.mozilla.org/MPL/ */ #ifndef MESSAGE_P_H #define MESSAGE_P_H #include "message.h" #include "arguments.h" #include "error.h" #include "itransportlistener.h" #include class ICompletionListener; class VarHeaderStorage { public: VarHeaderStorage(); VarHeaderStorage(const VarHeaderStorage &other); ~VarHeaderStorage(); bool hasHeader(Message::VariableHeader header) const; bool hasStringHeader(Message::VariableHeader header) const; std::string stringHeader(Message::VariableHeader header) const; cstring stringHeaderRaw(Message::VariableHeader header); void setStringHeader(Message::VariableHeader header, const std::string &value); void clearStringHeader(Message::VariableHeader header); bool hasIntHeader(Message::VariableHeader header) const; uint32 intHeader(Message::VariableHeader header) const; void setIntHeader(Message::VariableHeader header, uint32 value); void clearIntHeader(Message::VariableHeader header); // for use during header deserialization: returns false if a header occurs twice, // but does not check if the given header is of the right type (int / string). bool setIntHeader_deser(Message::VariableHeader header, uint32 value); bool setStringHeader_deser(Message::VariableHeader header, cstring value); const std::string *stringHeaders() const { return reinterpret_cast(m_stringStorage); } std::string *stringHeaders() { return reinterpret_cast(m_stringStorage); } static const int s_stringHeaderCount = 7; static const int s_intHeaderCount = 2; // Uninitialized storage for strings, to avoid con/destructing strings we'd never touch otherwise. std::aligned_storage::type m_stringStorage[VarHeaderStorage::s_stringHeaderCount]; uint32 m_intHeaders[s_intHeaderCount]; uint32 m_headerPresenceBitmap = 0; }; class MessagePrivate : public ITransportListener { public: static MessagePrivate *get(Message *m) { return m->d; } MessagePrivate(Message *parent); MessagePrivate(const MessagePrivate &other, Message *parent); ~MessagePrivate() override; IO::Status handleTransportCanRead() override; IO::Status handleTransportCanWrite() override; // ITransport is non-public API, so these make no sense in the public interface void receive(ITransport *transport); // fills in this message from transport void send(ITransport *transport); // sends this message over transport // for receive or send completion (it should be clear which because receiving and sending can't // happen simultaneously) void setCompletionListener(ICompletionListener *listener); bool requiredHeadersPresent(); Error checkRequiredHeaders() const; bool deserializeFixedHeaders(); bool deserializeVariableHeaders(); bool serialize(); void serializeFixedHeaders(); Arguments serializeVariableHeaders(); void clearBuffer(); + void clear(bool onlyReleaseResources = false); void reserveBuffer(uint32 newSize); void notifyCompletionListener(); + std::vector *argUnixFds(); + Message *m_message; chunk m_buffer; uint32 m_bufferPos; - std::vector m_fileDescriptors; bool m_isByteSwapped; enum { // ### we don't have an error state, the need hasn't arisen yet. strange! Empty = 0, Serialized, // This means that marshalled and "native format" data is in sync, which is really // the same whether the message was marshalled or demarshalled FirstIoState, Sending = FirstIoState, Receiving } m_state; Message::Type m_messageType; enum { NoReplyExpectedFlag = 0x1, NoAutoStartServiceFlag = 0x2, NoAllowInteractiveAuthorizationFlag = 0x4 }; byte m_flags; byte m_protocolVersion; bool m_dirty : 1; uint32 m_headerLength; uint32 m_headerPadding; uint32 m_bodyLength; uint32 m_serial; Error m_error; Arguments m_mainArguments; VarHeaderStorage m_varHeaders; ICompletionListener *m_completionListener; }; #endif // MESSAGE_P_H diff --git a/tests/connection/CMakeLists.txt b/tests/connection/CMakeLists.txt index e3be3f6..33027ef 100644 --- a/tests/connection/CMakeLists.txt +++ b/tests/connection/CMakeLists.txt @@ -1,12 +1,12 @@ -foreach(_testname connectaddress pendingreply server threads) +foreach(_testname connectaddress errorpropagation pendingreply server threads) add_executable(tst_${_testname} tst_${_testname}.cpp) set_target_properties(tst_${_testname} PROPERTIES COMPILE_FLAGS -DTEST_DATADIR="\\"${CMAKE_CURRENT_SOURCE_DIR}\\"") target_link_libraries(tst_${_testname} testutil dfer) add_test(connection/${_testname} tst_${_testname}) endforeach() if (UNIX) target_link_libraries(tst_threads pthread) target_link_libraries(tst_server pthread) endif() diff --git a/tests/connection/tst_errorpropagation.cpp b/tests/connection/tst_errorpropagation.cpp new file mode 100644 index 0000000..cb77d5d --- /dev/null +++ b/tests/connection/tst_errorpropagation.cpp @@ -0,0 +1,179 @@ +/* + Copyright (C) 2018 Andreas Hartmetz + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Library General Public + License as published by the Free Software Foundation; either + version 2 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Library General Public License for more details. + + You should have received a copy of the GNU Library General Public License + along with this library; see the file COPYING.LGPL. If not, write to + the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, + Boston, MA 02110-1301, USA. + + Alternatively, this file is available under the Mozilla Public License + Version 1.1. You may obtain a copy of the License at + http://www.mozilla.org/MPL/ +*/ + +#include "arguments.h" +#include "connectaddress.h" +#include "error.h" +#include "eventdispatcher.h" +#include "imessagereceiver.h" +#include "message.h" +#include "pendingreply.h" +#include "testutil.h" +#include "connection.h" + +#include + +static const char *s_testMethod = "dferryTestingMethod"; + +class ReplierReceiver : public IMessageReceiver +{ +public: + void handleSpontaneousMessageReceived(Message msg, Connection *connection) override + { + std::cerr << " Replier here. Yo, got it!" << std::endl; + // we're on the session bus, so we'll receive all kinds of notifications we don't care about here + if (msg.type() != Message::MethodCallMessage || msg.method() != s_testMethod) { + return; + } + // TODO also generate a malformed reply and see what happens + Message reply = Message::createReplyTo(msg); + connection->sendNoReply(std::move(reply)); + } +}; + +enum { StepsCount = 10 }; + +void test_errorPropagation() +{ + EventDispatcher eventDispatcher; + + // TODO do everything also with sendNoReply() + + for (int errorAtStep = 0; errorAtStep < StepsCount; errorAtStep++) { + + Connection conn(&eventDispatcher, ConnectAddress::StandardBus::Session); + conn.setDefaultReplyTimeout(500); + conn.waitForConnectionEstablished(); + TEST(conn.isConnected()); + + ReplierReceiver replier; + conn.setSpontaneousMessageReceiver(&replier); + + Arguments::Writer writer; + writer.beginVariant(); + + // If errorAtStep == 0, we do NOT introduce an error, just to check that the intentional + // errors are the only ones + + // The following pattern will repeat for every step where an error can be introduced + if (errorAtStep != 1) { + // Do it right + writer.writeUint32(0); + writer.endVariant(); + } else { + // Introduce an error + writer.endVariant(); // a variant may not be empty + } + + if (errorAtStep == 2) { + // too many file descriptors, we "magically" know that the max number of allowed file + // descriptors is 16. TODO it should be possible to ask the Connection about it(?) + for (int i = 0; i < 17; i++) { + // bogus file descriptors, shouldn't matter: the error should occur before they might + // possibly need to be valid + writer.writeUnixFd(100000); + } + } + + Message msg; + if (errorAtStep != 3) { + msg.setType(Message::MethodCallMessage); + } + + // not adding arguments to produce an error won't work - a call without arguments is fine! + msg.setArguments(writer.finish()); + + if (errorAtStep != 4) { + msg.setDestination(conn.uniqueName()); + } + if (errorAtStep != 5) { + msg.setPath("/foo/bar/dferry/testing"); + } + if (errorAtStep != 6) { + msg.setMethod(s_testMethod); + } + // Note interface is optional, so we can't introduce an error by omitting it (except with a signal, + // but we don't test signals) + + if (errorAtStep == 7) { + conn.close(); + } + + PendingReply reply = conn.send(std::move(msg)); + + if (errorAtStep == 8) { + // Since we haven't sent any (non-internal) messages yet, we rely on the send going through + // immediately, but the receive should fail due to this disconnect. + conn.close(); + } + + while (!reply.isFinished()) { + eventDispatcher.poll(); + } + +/* +Sources of error yet to do: +Message too large, other untested important Message properties? +Error reply from other side +Timeout +Malformed reply? +Malformed reply arguments? + + */ + + static const Error::Code expectedErrors[StepsCount] = { + Error::NoError, + Error::EmptyVariant, + Error::SendingTooManyUnixFds, + Error::MessageType, + Error::NoError, // TODO: probably wrong, message with no destination?! + Error::MessagePath, + Error::MessageMethod, + Error::LocalDisconnect, + Error::LocalDisconnect, + // TODO also test remote disconnect - or is that covered in another test? + Error::NoError // TODO: actually inject an error in that case + }; + + std::cerr << "Error at step " << errorAtStep << ": error code = " << reply.error().code() + << std::endl; + if (reply.reply()) { + std::cerr << " reply msg error code = " << reply.reply()->error().code() + << ", reply msg args error code = " << reply.reply()->arguments().error().code() + << std::endl; + } + + TEST(reply.error().code() == expectedErrors[errorAtStep]); + if (reply.reply()) { + TEST(reply.reply()->error().code() == expectedErrors[errorAtStep]); + TEST(reply.reply()->arguments().error().code() == expectedErrors[errorAtStep]); + } + + } +} + +int main(int, char *[]) +{ + test_errorPropagation(); + std::cout << "Passed!\n"; +} diff --git a/tests/connection/tst_threads.cpp b/tests/connection/tst_threads.cpp index 9174f63..bc8c488 100644 --- a/tests/connection/tst_threads.cpp +++ b/tests/connection/tst_threads.cpp @@ -1,238 +1,238 @@ /* Copyright (C) 2014 Andreas Hartmetz This library is free software; you can redistribute it and/or modify it under the terms of the GNU Library General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This library is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Library General Public License for more details. You should have received a copy of the GNU Library General Public License along with this library; see the file COPYING.LGPL. If not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. Alternatively, this file is available under the Mozilla Public License Version 1.1. You may obtain a copy of the License at http://www.mozilla.org/MPL/ */ #include "arguments.h" #include "connectaddress.h" #include "eventdispatcher.h" #include "imessagereceiver.h" #include "message.h" #include "pendingreply.h" #include "stringtools.h" #include "connection.h" #include "../testutil.h" #include #include #include static const char *echoPath = "/echo"; // make the name "fairly unique" because the interface name is our only protection against replying // to the wrong message static const char *echoInterface = "org.example_fb39a8dbd0aa66d2.echo"; static const char *echoMethod = "echo"; //////////////// Multi-thread ping-pong test //////////////// static const char *pingPayload = "-> J. Random PING"; static const char *pongPayload = "<- J. Random Pong"; class PongSender : public IMessageReceiver { public: void handleSpontaneousMessageReceived(Message ping, Connection *connection) override { if (ping.interface() != echoInterface) { // This is not the ping... it is probably still something from connection setup. // We can possibly receive many things here that we were not expecting. return; } { - Arguments args = ping.arguments(); + const Arguments args = ping.arguments(); Arguments::Reader reader(args); cstring payload = reader.readString(); TEST(!reader.error().isError()); TEST(reader.isFinished()); std::cout << "we have ping with payload: " << payload.ptr << std::endl; } { Message pong = Message::createReplyTo(ping); Arguments::Writer writer; writer.writeString(pongPayload); pong.setArguments(writer.finish()); std::cout << "\n\nSending pong!\n\n"; Error replyError = connection->sendNoReply(std::move(pong)); TEST(!replyError.isError()); connection->eventDispatcher()->interrupt(); } } }; static void pongThreadRun(Connection::CommRef mainConnectionRef, std::atomic *pongThreadReady) { std::cout << " Pong thread starting!\n"; EventDispatcher eventDispatcher; Connection conn(&eventDispatcher, std::move(mainConnectionRef)); PongSender pongSender; conn.setSpontaneousMessageReceiver(&pongSender); while (eventDispatcher.poll()) { std::cout << " Pong thread waking up!\n"; if (conn.uniqueName().length()) { pongThreadReady->store(true); // HACK: we do this only to wake up the main thread's event loop std::cout << "\n\nSending WAKEUP package!!\n\n"; Message wakey = Message::createCall(echoPath, "org.notexample.foo", echoMethod); wakey.setDestination(conn.uniqueName()); conn.sendNoReply(std::move(wakey)); } else { std::cout << " Pong thread: NO NAME YET!\n"; } // receive ping message // send pong message } std::cout << " Pong thread almost finished!\n"; } class PongReceiver : public IMessageReceiver { public: void handlePendingReplyFinished(PendingReply *pongReply, Connection *) override { TEST(!pongReply->error().isError()); Message pong = pongReply->takeReply(); Arguments args = pong.arguments(); Arguments::Reader reader(args); std::string strPayload = toStdString(reader.readString()); TEST(!reader.error().isError()); TEST(reader.isFinished()); TEST(strPayload == pongPayload); } }; static void testPingPong() { EventDispatcher eventDispatcher; Connection conn(&eventDispatcher, ConnectAddress::StandardBus::Session); std::atomic pongThreadReady(false); std::thread pongThread(pongThreadRun, conn.createCommRef(), &pongThreadReady); // finish creating the connection while (conn.uniqueName().empty()) { std::cout << "."; eventDispatcher.poll(); } std::cout << "we have connection! " << conn.uniqueName() << "\n"; // send ping message to other thread Message ping = Message::createCall(echoPath, echoInterface, echoMethod); Arguments::Writer writer; writer.writeString(pingPayload); ping.setArguments(writer.finish()); ping.setDestination(conn.uniqueName()); PongReceiver pongReceiver; PendingReply pongReply; bool sentPing = false; while (!sentPing || !pongReply.isFinished()) { eventDispatcher.poll(); if (pongThreadReady.load() && !sentPing) { std::cout << "\n\nSending ping!!\n\n"; pongReply = conn.send(std::move(ping)); pongReply.setReceiver(&pongReceiver); sentPing = true; } } TEST(pongReply.hasNonErrorReply()); std::cout << "we have pong!\n"; pongThread.join(); } //////////////// Multi-threaded timeout test //////////////// class TimeoutReceiver : public IMessageReceiver { public: void handlePendingReplyFinished(PendingReply *reply, Connection *) override { TEST(reply->isFinished()); TEST(!reply->hasNonErrorReply()); TEST(reply->error().code() == Error::Timeout); std::cout << "We HAVE timed out.\n"; } }; static void timeoutThreadRun(Connection::CommRef mainConnectionRef, std::atomic *done) { // TODO v turn this into proper documentation in Connection // Open a Connection "slaved" to the other Connection - it runs its own event loop in this thread // and has message I/O handled by the Connection in the "master" thread through message passing. // The main purpose of that is to use just one DBus connection per application( module), which is often // more convenient for client programmers and brings some limited ordering guarantees. std::cout << " Other thread starting!\n"; EventDispatcher eventDispatcher; Connection conn(&eventDispatcher, std::move(mainConnectionRef)); while (!conn.uniqueName().length()) { eventDispatcher.poll(); } Message notRepliedTo = Message::createCall(echoPath, echoInterface, echoMethod); notRepliedTo.setDestination(conn.uniqueName()); PendingReply deadReply = conn.send(std::move(notRepliedTo), 50); TimeoutReceiver timeoutReceiver; deadReply.setReceiver(&timeoutReceiver); while (!deadReply.isFinished()) { eventDispatcher.poll(); } *done = true; } static void testThreadedTimeout() { EventDispatcher eventDispatcher; Connection conn(&eventDispatcher, ConnectAddress::StandardBus::Session); std::atomic done(false); std::thread timeoutThread(timeoutThreadRun, conn.createCommRef(), &done); while (!done) { eventDispatcher.poll(); } timeoutThread.join(); } // more things to test: // - (do we want to do this, and if so here??) blocking on a reply through other thread's connection // - ping-pong with several messages queued - every message should arrive exactly once and messages // should arrive in sending order (can use serials for that as simplificitaion) int main(int, char *[]) { testPingPong(); testThreadedTimeout(); std::cout << "Passed!\n"; } diff --git a/transport/ipsocket.cpp b/transport/ipsocket.cpp index eb9e9cf..353173e 100644 --- a/transport/ipsocket.cpp +++ b/transport/ipsocket.cpp @@ -1,229 +1,233 @@ /* Copyright (C) 2015 Andreas Hartmetz This library is free software; you can redistribute it and/or modify it under the terms of the GNU Library General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This library is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Library General Public License for more details. You should have received a copy of the GNU Library General Public License along with this library; see the file COPYING.LGPL. If not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. Alternatively, this file is available under the Mozilla Public License Version 1.1. You may obtain a copy of the License at http://www.mozilla.org/MPL/ */ #include "ipsocket.h" #include "connectaddress.h" #ifdef __unix__ #include #include #include #include #include #endif #ifdef _WIN32 #include typedef SSIZE_T ssize_t; #endif #include #include #include #include #include #include -// HACK, put this somewhere else (get the value from original d-bus? or is it infinite?) -static const int maxFds = 12; - // TODO implement address family (IPv4 / IPv6) support IpSocket::IpSocket(const ConnectAddress &ca) : m_fd(-1) { assert(ca.type() == ConnectAddress::Type::Tcp); #ifdef _WIN32 WSAData wsadata; // IPv6 requires Winsock v2.0 or better (but we're not using IPv6 - yet!) if (WSAStartup(MAKEWORD(2, 0), &wsadata) != 0) { std::cerr << "IpSocket contruction failed A.\n"; return; } #endif const FileDescriptor fd = socket(AF_INET, SOCK_STREAM, 0); if (!isValidFileDescriptor(fd)) { std::cerr << "IpSocket contruction failed B.\n"; return; } struct sockaddr_in addr; addr.sin_family = AF_INET; addr.sin_port = htons(ca.port()); addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); bool ok = connect(fd, (struct sockaddr *)&addr, sizeof(addr)) == 0; // only make it non-blocking after connect() because Winsock returns // WSAEWOULDBLOCK when connecting a non-blocking socket #ifdef _WIN32 unsigned long value = 1; // 0 blocking, != 0 non-blocking if (ioctlsocket(fd, FIONBIO, &value) != NO_ERROR) { // something along the lines of... WS_ERROR_DEBUG(WSAGetLastError()); std::cerr << "IpSocket contruction failed C.\n"; closesocket(fd); return; } #else // don't let forks inherit the file descriptor - that can cause confusion... fcntl(fd, F_SETFD, FD_CLOEXEC); // To be able to use the same send() and recv() calls as Windows, also set the non-blocking // property on the socket descriptor here instead of passing MSG_DONTWAIT to send() and recv(). const int oldFlags = fcntl(fd, F_GETFL); if (oldFlags == -1) { ::close(fd); std::cerr << "IpSocket contruction failed D.\n"; return; } fcntl(fd, F_SETFL, oldFlags & O_NONBLOCK); #endif if (ok) { m_fd = fd; } else { #ifdef _WIN32 std::cerr << "IpSocket contruction failed E. Error is " << WSAGetLastError() << ".\n"; closesocket(fd); #else std::cerr << "IpSocket contruction failed E. Error is " << errno << ".\n"; ::close(fd); #endif } } IpSocket::IpSocket(FileDescriptor fd) : m_fd(fd) { } IpSocket::~IpSocket() { close(); #ifdef _WIN32 WSACleanup(); #endif } void IpSocket::platformClose() { if (isValidFileDescriptor(m_fd)) { #ifdef _WIN32 closesocket(m_fd); #else ::close(m_fd); #endif m_fd = InvalidFileDescriptor; } } IO::Result IpSocket::write(chunk a) { IO::Result ret; if (!isValidFileDescriptor(m_fd)) { std::cerr << "\nIpSocket::write() failed A.\n\n"; ret.status = IO::Status::InternalError; return ret; } const uint32 initialLength = a.length; while (a.length > 0) { ssize_t nbytes = send(m_fd, reinterpret_cast(a.ptr), a.length, 0); if (nbytes < 0) { if (errno == EINTR) { continue; } // see EAGAIN comment in LocalSocket::read() if (errno == EAGAIN) { break; } close(); ret.status = IO::Status::InternalError; return ret; + } else if (nbytes == 0) { + break; } a.ptr += nbytes; a.length -= uint32(nbytes); } ret.length = initialLength - a.length; return ret; } uint32 IpSocket::availableBytesForReading() { #ifdef _WIN32 u_long available = 0; if (ioctlsocket(m_fd, FIONREAD, &available) != NO_ERROR) { #else uint32 available = 0; if (ioctl(m_fd, FIONREAD, &available) < 0) { #endif available = 0; } return uint32(available); } IO::Result IpSocket::read(byte *buffer, uint32 maxSize) { IO::Result ret; if (maxSize <= 0) { std::cerr << "\nIpSocket::read() failed A.\n\n"; ret.status = IO::Status::InternalError; return ret; } while (maxSize > 0) { ssize_t nbytes = recv(m_fd, reinterpret_cast(buffer), maxSize, 0); if (nbytes < 0) { if (errno == EINTR) { continue; } // see comment in LocalSocket for rationale of EAGAIN behavior - if (errno == EAGAIN) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { break; } close(); - ret.status = IO::Status::RemoteClosed; // TODO - return ret; + ret.status = IO::Status::RemoteClosed; + break; + } else if (nbytes == 0) { + // orderly shutdown + close(); + ret.status = IO::Status::RemoteClosed; + break; } ret.length += uint32(nbytes); buffer += nbytes; maxSize -= uint32(nbytes); } return ret; } bool IpSocket::isOpen() { return isValidFileDescriptor(m_fd); } FileDescriptor IpSocket::fileDescriptor() const { return m_fd; } diff --git a/transport/itransport.cpp b/transport/itransport.cpp index 75f45b8..540ec8c 100644 --- a/transport/itransport.cpp +++ b/transport/itransport.cpp @@ -1,147 +1,154 @@ /* Copyright (C) 2013 Andreas Hartmetz This library is free software; you can redistribute it and/or modify it under the terms of the GNU Library General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This library is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Library General Public License for more details. You should have received a copy of the GNU Library General Public License along with this library; see the file COPYING.LGPL. If not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. Alternatively, this file is available under the Mozilla Public License Version 1.1. You may obtain a copy of the License at http://www.mozilla.org/MPL/ */ #include "itransport.h" #include "eventdispatcher.h" #include "eventdispatcher_p.h" #include "itransportlistener.h" #include "ipsocket.h" #include "connectaddress.h" #ifdef __unix__ #include "localsocket.h" #endif #include #include ITransport::ITransport() { } ITransport::~ITransport() { setReadListener(nullptr); setWriteListener(nullptr); } IO::Result ITransport::readWithFileDescriptors(byte *buffer, uint32 maxSize, std::vector *) { + // This may be OK - if the other side tried to send file descriptors we will later notice + // the difference between the message header declaring a nonzero FD count and the actual + // received FD count which will always be zero if we get here. IF the other side didn't try + // to send FDs, there is no problem. return read(buffer, maxSize); } -IO::Result ITransport::writeWithFileDescriptors(chunk data, const std::vector &) +IO::Result ITransport::writeWithFileDescriptors(chunk, const std::vector &) { - return write(data); + // Just don't call this on a transport that doesn't support passing file descriptors + IO::Result res; + res.status = IO::Status::LocalClosed; + return res; } void ITransport::setReadListener(ITransportListener *listener) { if (m_readListener != listener) { if (m_readListener) { m_readListener->m_readTransport = nullptr; } if (listener) { if (listener->m_readTransport) { listener->m_readTransport->setReadListener(nullptr); } assert(!listener->m_readTransport); listener->m_readTransport = this; } m_readListener = listener; } updateTransportIoInterest(); } void ITransport::setWriteListener(ITransportListener *listener) { if (m_writeListener != listener) { if (m_writeListener) { m_writeListener->m_writeTransport = nullptr; } if (listener) { if (listener->m_writeTransport) { listener->m_writeTransport->setWriteListener(nullptr); } assert(!listener->m_writeTransport); listener->m_writeTransport = this; } m_writeListener = listener; } updateTransportIoInterest(); } void ITransport::updateTransportIoInterest() { setIoInterest((m_readListener ? uint32(IO::RW::Read) : 0) | (m_writeListener ? uint32(IO::RW::Write) : 0)); } void ITransport::close() { if (!isOpen()) { return; } if (ioEventSource()) { ioEventSource()->removeIoListener(this); } platformClose(); } IO::Status ITransport::handleIoReady(IO::RW rw) { IO::Status ret = IO::Status::OK; assert(uint32(rw) & ioInterest()); // only get notified about events we requested if (rw == IO::RW::Read && m_readListener) { ret = m_readListener->handleTransportCanRead(); } else if (rw == IO::RW::Write && m_writeListener) { ret = m_writeListener->handleTransportCanWrite(); } else { assert(false); } if (ret != IO::Status::OK) { // TODO call some common close, cleanup & report error method } return ret; } //static ITransport *ITransport::create(const ConnectAddress &ci) { switch (ci.type()) { #ifdef __unix__ case ConnectAddress::Type::UnixPath: return new LocalSocket(ci.path()); case ConnectAddress::Type::AbstractUnixPath: // TODO this is Linux only, reflect it in code return new LocalSocket(std::string(1, '\0') + ci.path()); #endif case ConnectAddress::Type::Tcp: case ConnectAddress::Type::Tcp4: case ConnectAddress::Type::Tcp6: return new IpSocket(ci); default: assert(false); return nullptr; } } diff --git a/transport/itransport.h b/transport/itransport.h index 4eb9887..dd36e5a 100644 --- a/transport/itransport.h +++ b/transport/itransport.h @@ -1,82 +1,82 @@ /* Copyright (C) 2013 Andreas Hartmetz This library is free software; you can redistribute it and/or modify it under the terms of the GNU Library General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This library is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Library General Public License for more details. You should have received a copy of the GNU Library General Public License along with this library; see the file COPYING.LGPL. If not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. Alternatively, this file is available under the Mozilla Public License Version 1.1. You may obtain a copy of the License at http://www.mozilla.org/MPL/ */ #ifndef ITRANSPORT_H #define ITRANSPORT_H #include "iioeventlistener.h" #include "platform.h" #include "types.h" #include class ConnectAddress; class EventDispatcher; class ITransportListener; class SelectEventPoller; class ITransport : public IIoEventListener { public: // An ITransport subclass must have a file descriptor after construction and it must not change // except to the invalid file descriptor when disconnected. - ITransport(); // TODO event dispatcher as constructor argument? + ITransport(); ~ITransport() override; // This listener interface is different from IIoEventSource / IIoEventListener because that one is // one source, several file descriptors, one file descriptor to one listener // this is one file descriptor, two channels, one channel (read or write) to one listener. void setReadListener(ITransportListener *listener); void setWriteListener(ITransportListener *listener); virtual uint32 availableBytesForReading() = 0; virtual IO::Result read(byte *buffer, uint32 maxSize) = 0; virtual IO::Result readWithFileDescriptors(byte *buffer, uint32 maxSize, std::vector *fileDescriptors); virtual IO::Result write(chunk data) = 0; virtual IO::Result writeWithFileDescriptors(chunk data, const std::vector &fileDescriptors); void close(); virtual bool isOpen() = 0; - bool supportsPassingFileDescriptors() const { return m_supportsFileDescriptors; } + uint32 supportedPassingUnixFdsCount() const { return m_supportedUnixFdsCount; } IO::Status handleIoReady(IO::RW rw) override; // factory method - creates a suitable subclass to connect to address static ITransport *create(const ConnectAddress &connectAddress); protected: virtual void platformClose() = 0; - bool m_supportsFileDescriptors = false; + uint32 m_supportedUnixFdsCount = 0; private: void updateTransportIoInterest(); // "Transport" in name to avoid confusion with IIoEventSource friend class ITransportListener; friend class SelectEventPoller; ITransportListener *m_readListener = nullptr; ITransportListener *m_writeListener = nullptr; }; #endif // ITRANSPORT_H diff --git a/transport/localsocket.cpp b/transport/localsocket.cpp index 5c5dc74..c65db8f 100644 --- a/transport/localsocket.cpp +++ b/transport/localsocket.cpp @@ -1,348 +1,352 @@ /* Copyright (C) 2013 Andreas Hartmetz This library is free software; you can redistribute it and/or modify it under the terms of the GNU Library General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This library is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Library General Public License for more details. You should have received a copy of the GNU Library General Public License along with this library; see the file COPYING.LGPL. If not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. Alternatively, this file is available under the Mozilla Public License Version 1.1. You may obtain a copy of the License at http://www.mozilla.org/MPL/ */ #include "localsocket.h" #include #include #include #include #include "sys/uio.h" #include #include #include #include #include #include enum { // ### This is configurable in libdbus-1 but nobody ever seems to change it from the default of 16. MaxFds = 16, MaxFdPayloadSize = MaxFds * sizeof(int) }; LocalSocket::LocalSocket(const std::string &socketFilePath) : m_fd(-1) { - m_supportsFileDescriptors = true; + m_supportedUnixFdsCount = MaxFds; const int fd = socket(PF_UNIX, SOCK_STREAM, 0); if (fd < 0) { return; } // don't let forks inherit the file descriptor - that can cause confusion... fcntl(fd, F_SETFD, FD_CLOEXEC); struct sockaddr_un addr; addr.sun_family = PF_UNIX; bool ok = socketFilePath.length() + 1 <= sizeof(addr.sun_path); if (ok) { memcpy(addr.sun_path, socketFilePath.c_str(), socketFilePath.length() + 1); } ok = ok && (connect(fd, (struct sockaddr *)&addr, sizeof(sa_family_t) + socketFilePath.length()) == 0); if (ok) { m_fd = fd; } else { ::close(fd); } } LocalSocket::LocalSocket(int fd) : m_fd(fd) { } LocalSocket::~LocalSocket() { close(); } void LocalSocket::platformClose() { if (m_fd >= 0) { ::close(m_fd); m_fd = -1; } } IO::Result LocalSocket::write(chunk data) { IO::Result ret; if (data.length == 0) { return ret; } if (m_fd < 0) { ret.status = IO::Status::InternalError; return ret; } const uint32 initialLength = data.length; while (data.length > 0) { ssize_t nbytes = send(m_fd, data.ptr, data.length, MSG_DONTWAIT); if (nbytes < 0) { if (errno == EINTR) { continue; } // see EAGAIN comment in read() if (errno == EAGAIN || errno == EWOULDBLOCK) { break; } close(); ret.status = IO::Status::RemoteClosed; return ret; + } else if (nbytes == 0) { + break; } data.ptr += nbytes; data.length -= size_t(nbytes); } ret.length = initialLength - data.length; return ret; } // TODO: consider using iovec to avoid "copying together" message parts before sending; iovec tricks // are probably not going to help for receiving, though. IO::Result LocalSocket::writeWithFileDescriptors(chunk data, const std::vector &fileDescriptors) { IO::Result ret; if (data.length == 0) { return ret; } if (m_fd < 0) { ret.status = IO::Status::InternalError; return ret; } // sendmsg boilerplate struct msghdr send_msg; struct iovec iov; send_msg.msg_name = 0; send_msg.msg_namelen = 0; send_msg.msg_flags = 0; send_msg.msg_iov = &iov; send_msg.msg_iovlen = 1; iov.iov_base = data.ptr; iov.iov_len = data.length; // we can only send a fixed number of fds anyway due to the non-flexible size of the control message // receive buffer, so we set an arbitrary limit. const uint32 numFds = fileDescriptors.size(); if (fileDescriptors.size() > MaxFds) { // TODO allow a proper error return close(); ret.status = IO::Status::InternalError; return ret; } char cmsgBuf[CMSG_SPACE(MaxFdPayloadSize)]; const uint32 fdPayloadSize = numFds * sizeof(int); if (numFds) { // fill in a control message send_msg.msg_control = cmsgBuf; send_msg.msg_controllen = CMSG_SPACE(fdPayloadSize); struct cmsghdr *c_msg = CMSG_FIRSTHDR(&send_msg); c_msg->cmsg_len = CMSG_LEN(fdPayloadSize); c_msg->cmsg_level = SOL_SOCKET; c_msg->cmsg_type = SCM_RIGHTS; // set the control data to pass - this is why we don't use the simpler write() int *const fdPayload = reinterpret_cast(CMSG_DATA(c_msg)); for (uint32 i = 0; i < numFds; i++) { fdPayload[i] = fileDescriptors[i]; } } else { // no file descriptor to send, no control message send_msg.msg_control = nullptr; send_msg.msg_controllen = 0; } while (iov.iov_len > 0) { ssize_t nbytes = sendmsg(m_fd, &send_msg, MSG_DONTWAIT); if (nbytes < 0) { if (errno == EINTR) { continue; } // see EAGAIN comment in read() if (errno == EAGAIN || errno == EWOULDBLOCK) { break; } close(); ret.status = IO::Status::RemoteClosed; break; + } else if (nbytes == 0) { + break; } else if (nbytes > 0) { // control message already sent, don't send again send_msg.msg_control = nullptr; send_msg.msg_controllen = 0; } iov.iov_base = static_cast(iov.iov_base) + nbytes; iov.iov_len -= size_t(nbytes); } ret.length = data.length - iov.iov_len; return ret; } uint32 LocalSocket::availableBytesForReading() { uint32 available = 0; if (ioctl(m_fd, FIONREAD, &available) < 0) { available = 0; } return available; } IO::Result LocalSocket::read(byte *buffer, uint32 maxSize) { IO::Result ret; if (maxSize == 0) { return ret; } if (m_fd < 0) { ret.status = IO::Status::InternalError; return ret; } while (ret.length < maxSize) { ssize_t nbytes = recv(m_fd, buffer + ret.length, maxSize - ret.length, MSG_DONTWAIT); if (nbytes < 0) { if (errno == EINTR) { continue; } // If we were notified for reading directly by the event dispatcher, we must be able to read at // least one byte before getting AGAIN aka EWOULDBLOCK - *however* the event loop might notify // something that tries to read everything (like Message::notifyRead()...) by calling read() // in a loop, and in that case, we may be called in an attempt to read more when there is // currently no more data, and it's not an error. // Just return zero bytes and no error in that case. if (errno == EAGAIN || errno == EWOULDBLOCK) { break; } close(); ret.status = IO::Status::RemoteClosed; break; } else if (nbytes == 0) { // orderly shutdown close(); ret.status = IO::Status::RemoteClosed; return ret; } ret.length += size_t(nbytes); } return ret; } IO::Result LocalSocket::readWithFileDescriptors(byte *buffer, uint32 maxSize, std::vector *fileDescriptors) { IO::Result ret; if (maxSize == 0) { return ret; } if (m_fd < 0) { ret.status = IO::Status::InternalError; return ret; } // recvmsg-with-control-message boilerplate struct msghdr recv_msg; char cmsgBuf[CMSG_SPACE(sizeof(int) * MaxFds)]; recv_msg.msg_control = cmsgBuf; recv_msg.msg_controllen = CMSG_SPACE(MaxFdPayloadSize); memset(cmsgBuf, 0, recv_msg.msg_controllen); // prevent equivalent to CVE-2014-3635 in libdbus-1: We could receive and ignore an extra file // descriptor, thus eventually run out of file descriptors recv_msg.msg_controllen = CMSG_LEN(MaxFdPayloadSize); recv_msg.msg_name = 0; recv_msg.msg_namelen = 0; recv_msg.msg_flags = 0; struct iovec iov; recv_msg.msg_iov = &iov; recv_msg.msg_iovlen = 1; // end boilerplate iov.iov_base = buffer; iov.iov_len = maxSize; while (iov.iov_len > 0) { ssize_t nbytes = recvmsg(m_fd, &recv_msg, MSG_DONTWAIT); if (nbytes < 0) { if (errno == EINTR) { continue; } // see comment in read() if (errno == EAGAIN || errno == EWOULDBLOCK) { break; } close(); ret.status = IO::Status::RemoteClosed; break; } else if (nbytes == 0) { // orderly shutdown close(); ret.status = IO::Status::RemoteClosed; break; } else { // read any file descriptors passed via control messages struct cmsghdr *c_msg = CMSG_FIRSTHDR(&recv_msg); if (c_msg && c_msg->cmsg_level == SOL_SOCKET && c_msg->cmsg_type == SCM_RIGHTS) { const int count = (c_msg->cmsg_len - CMSG_LEN(0)) / sizeof(int); const int *const fdPayload = reinterpret_cast(CMSG_DATA(c_msg)); for (int i = 0; i < count; i++) { fileDescriptors->push_back(fdPayload[i]); } } // control message already received, don't receive another recv_msg.msg_control = nullptr; recv_msg.msg_controllen = 0; } ret.length += size_t(nbytes); iov.iov_base = static_cast(iov.iov_base) + nbytes; iov.iov_len -= size_t(nbytes); } return ret; } bool LocalSocket::isOpen() { return m_fd != -1; } int LocalSocket::fileDescriptor() const { return m_fd; } diff --git a/util/error.h b/util/error.h index bcadfc1..346836a 100644 --- a/util/error.h +++ b/util/error.h @@ -1,141 +1,144 @@ /* Design notes about errors. Errors can come (including but not limited to...) from these areas: - Arguments assembly - invalid construct, e.g. empty struct, dict with key but no value, dict with invalid key type, writing different (non-variant) types in different array elements - limit exceeded (message size, nesting depth etc) - invalid single data (e.g. null in string, too long string) - Arguments disassembly - malformed data (mostly manifesting as limit exceeded, since the format has little room for "grammar errors" - almost everything could theoretically be valid data) - invalid single data - trying to read something incompatible with reader state - Message assembly - required headers not present - Message disassembly - required headers not present (note: sender header in bus connections! not currently checked.) - I/O errors - could not open connection - any sub-codes? - disconnected - timeout?? - (read a malformed message - connection should be closed) - discrepancy in number of file descriptors advertised and actually received - when this is implemented - artifacts of the implementation; not much - using a default-constructed PendingReply, anything else? - error codes from standardized DBus interfaces like the introspection thing; I think the convenience stuff should really be separate! Maybe separate namespace, in any case separate enum an error (if any) propagates in the following way, so you don't need to check at every step: Arguments::Writer -> Arguments -> Message -> PendingReply */ #ifndef ERROR_H #define ERROR_H #include "types.h" #include class DFERRY_EXPORT Error { public: enum Code : uint32 { // Error error ;) NoError = 0, // Arguments errors NotAttachedToArguments, InvalidSignature, ReplacementDataIsShorter, MalformedMessageData, ReadWrongType, NotPrimitiveType, InvalidType, InvalidString, InvalidObjectPath, SignatureTooLong, ExcessiveNesting, CannotEndArgumentsHere, ArgumentsTooLong, NotSingleCompleteTypeInVariant, EmptyVariant, CannotEndVariantHere, EmptyStruct, CannotEndStructHere, NotSingleCompleteTypeInArray, TypeMismatchInSubsequentArrayIteration, CannotEndArrayHere, CannotEndArrayOrDictHere, TooFewTypesInArrayOrDict, InvalidStateToRestartEmptyArray, InvalidKeyTypeInDict, GreaterTwoTypesInDict, ArrayOrDictTooLong, + SendingTooManyUnixFds, // The FD capacity varies by transport, so this error is only produced + // when trying to send a message with too many FDs. It is fine to pass + // around a message with lots of file descriptors locally. StateNotSkippable, MissingBeginDictEntry = 1019, MisplacedBeginDictEntry, MissingEndDictEntry, MisplacedEndDictEntry, // we have a lot of error codes at our disposal, so reserve some for easy classification // by range MaxArgumentsError = 1023, // end Arguments errors // Message / PendingReply DetachedPendingReply, Timeout, RemoteDisconnect, LocalDisconnect, MalformedReply, // Catch-all for failed reply validation - can't be corrected locally anyway. // Since the reply isn't fully pre-validated for performance reasons, // absence of this error is no guarantee of well-formedness. MessageType, // ||| all of these may potentially mean missing for the type of message MessageSender, // vvv or locally found to be invalid (invalid object path for example) MessageDestination, MessagePath, MessageInterface, MessageSignature, MessageMethod, MessageErrorName, MessageSerial, MessageReplySerial, MessageProtocolVersion, PeerNoSuchReceiver, PeerNoSuchPath, PeerNoSuchInterface, PeerNoSuchMethod, ArgumentTypeMismatch, PeerInvalidProperty, PeerNoSuchProperty, AccessDenied, // for now(?) only properties: writing to read-only / reading from write-only MaxMessageError = 2047 // end Message / PendingReply errors // errors for other occasions go here }; Error() : m_code(NoError) {} Error(Code code) : m_code(code) {} void setCode(Code code) { m_code = code; } Code code() const { return m_code; } bool isError() const { return m_code != NoError; } // no setter for message - it is just looked up from a static table according to error code std::string message() const; private: Code m_code; }; #endif // ERROR_H diff --git a/util/iovaluetypes.h b/util/iovaluetypes.h index d90e636..214215b 100644 --- a/util/iovaluetypes.h +++ b/util/iovaluetypes.h @@ -1,56 +1,57 @@ /* Copyright (C) 2018 Andreas Hartmetz This library is free software; you can redistribute it and/or modify it under the terms of the GNU Library General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This library is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Library General Public License for more details. You should have received a copy of the GNU Library General Public License along with this library; see the file COPYING.LGPL. If not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. Alternatively, this file is available under the Mozilla Public License Version 1.1. You may obtain a copy of the License at http://www.mozilla.org/MPL/ */ #ifndef IOVALUETYPES_H #define IOVALUETYPES_H #include "types.h" namespace IO { // would be nice to wrap this into a type-safe bitset enum / class, but since it's for // internal use, uint32 is okay... enum class RW { Read = 1, Write = 2, }; enum class Status { OK = 0, RemoteClosed, LocalClosed, + PayloadError, InternalError }; struct Result { Status status = Status::OK; uint32 length = 0; }; } // namespace IO #endif // IOVALUETYPES_H