// Copyright (c) 2012-2015, The CryptoNote developers, The Bytecoin developers // // This file is part of Bytecoin. // // Bytecoin is free software: you can redistribute it and/or modify // it under the terms of the GNU Lesser General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // Bytecoin is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Lesser General Public License for more details. // // You should have received a copy of the GNU Lesser General Public License // along with Bytecoin. If not, see . #include "TcpConnection.h" #include #ifndef WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN #endif #include #include #include #include #include "Dispatcher.h" namespace System { namespace { struct TcpConnectionContext : public OVERLAPPED { void* context; bool interrupted; }; } TcpConnection::TcpConnection() : dispatcher(nullptr) { } TcpConnection::TcpConnection(TcpConnection&& other) : dispatcher(other.dispatcher) { if (dispatcher != nullptr) { assert(other.readContext == nullptr); assert(other.writeContext == nullptr); connection = other.connection; stopped = other.stopped; readContext = nullptr; writeContext = nullptr; other.dispatcher = nullptr; } } TcpConnection::~TcpConnection() { if (dispatcher != nullptr) { assert(readContext == nullptr); assert(writeContext == nullptr); int result = closesocket(connection); assert(result == 0); } } TcpConnection& TcpConnection::operator=(TcpConnection&& other) { if (dispatcher != nullptr) { assert(readContext == nullptr); assert(writeContext == nullptr); if (closesocket(connection) != 0) { throw std::runtime_error("TcpConnection::operator=, closesocket failed, result=" + std::to_string(WSAGetLastError())); } } dispatcher = other.dispatcher; if (dispatcher != nullptr) { assert(other.readContext == nullptr); assert(other.writeContext == nullptr); connection = other.connection; stopped = other.stopped; readContext = nullptr; writeContext = nullptr; other.dispatcher = nullptr; } return *this; } void TcpConnection::start() { assert(dispatcher != nullptr); assert(stopped); stopped = false; } void TcpConnection::stop() { assert(dispatcher != nullptr); assert(!stopped); if (readContext != nullptr) { TcpConnectionContext* context = static_cast(readContext); if (!context->interrupted) { if (CancelIoEx(reinterpret_cast(connection), context) != TRUE) { DWORD lastError = GetLastError(); if (lastError != ERROR_NOT_FOUND) { throw std::runtime_error("TcpConnection::stop, CancelIoEx failed, result=" + std::to_string(GetLastError())); } } context->interrupted = true; } } if (writeContext != nullptr) { TcpConnectionContext* context = static_cast(writeContext); if (!context->interrupted) { if (CancelIoEx(reinterpret_cast(connection), context) != TRUE) { DWORD lastError = GetLastError(); if (lastError != ERROR_NOT_FOUND) { throw std::runtime_error("TcpConnection::stop, CancelIoEx failed, result=" + std::to_string(GetLastError())); } } context->interrupted = true; } } stopped = true; } size_t TcpConnection::read(uint8_t* data, size_t size) { assert(dispatcher != nullptr); assert(readContext == nullptr); if (stopped) { throw InterruptedException(); } WSABUF buf{static_cast(size), reinterpret_cast(data)}; DWORD flags = 0; TcpConnectionContext context; context.hEvent = NULL; if (WSARecv(connection, &buf, 1, NULL, &flags, &context, NULL) != 0) { int lastError = WSAGetLastError(); if (lastError != WSA_IO_PENDING) { throw std::runtime_error("TcpConnection::read, WSARecv failed, result=" + std::to_string(lastError)); } } assert(flags == 0); context.context = GetCurrentFiber(); context.interrupted = false; readContext = &context; dispatcher->dispatch(); assert(context.context == GetCurrentFiber()); assert(dispatcher != nullptr); assert(readContext == &context); readContext = nullptr; DWORD transferred; if (WSAGetOverlappedResult(connection, &context, &transferred, FALSE, &flags) != TRUE) { int lastError = WSAGetLastError(); if (lastError != ERROR_OPERATION_ABORTED) { throw std::runtime_error("TcpConnection::read, WSARecv failed, result=" + std::to_string(lastError)); } assert(context.interrupted); throw InterruptedException(); } assert(transferred <= size); assert(flags == 0); return transferred; } std::size_t TcpConnection::write(const uint8_t* data, size_t size) { assert(dispatcher != nullptr); assert(writeContext == nullptr); if (stopped) { throw InterruptedException(); } if (size == 0) { if (shutdown(connection, SD_SEND) != 0) { throw std::runtime_error("TcpConnection::write, shutdown failed, result=" + std::to_string(WSAGetLastError())); } return 0; } WSABUF buf{static_cast(size), reinterpret_cast(const_cast(data))}; TcpConnectionContext context; context.hEvent = NULL; if (WSASend(connection, &buf, 1, NULL, 0, &context, NULL) != 0) { int lastError = WSAGetLastError(); if (lastError != WSA_IO_PENDING) { throw std::runtime_error("TcpConnection::write, WSASend failed, result=" + std::to_string(lastError)); } } context.context = GetCurrentFiber(); context.interrupted = false; writeContext = &context; dispatcher->dispatch(); assert(context.context == GetCurrentFiber()); assert(dispatcher != nullptr); assert(writeContext == &context); writeContext = nullptr; DWORD transferred; DWORD flags; if (WSAGetOverlappedResult(connection, &context, &transferred, FALSE, &flags) != TRUE) { int lastError = WSAGetLastError(); if (lastError != ERROR_OPERATION_ABORTED) { throw std::runtime_error("TcpConnection::write, WSASend failed, result=" + std::to_string(lastError)); } assert(context.interrupted); throw InterruptedException(); } assert(transferred == size); assert(flags == 0); return transferred; } std::pair TcpConnection::getPeerAddressAndPort() { sockaddr_in address; int size = sizeof(address); if (getpeername(connection, reinterpret_cast(&address), &size) != 0) { throw std::runtime_error("TcpConnection::getPeerAddress, getpeername failed, result=" + std::to_string(WSAGetLastError())); } assert(size == sizeof(sockaddr_in)); return std::make_pair(Ipv4Address(htonl(address.sin_addr.S_un.S_addr)), htons(address.sin_port)); } TcpConnection::TcpConnection(Dispatcher& dispatcher, std::size_t connection) : dispatcher(&dispatcher), connection(connection), stopped(false), readContext(nullptr), writeContext(nullptr) { } }