diff --git a/shell_integration/windows/OCShellExtensions/OCOverlays/OCOverlay.cpp b/shell_integration/windows/OCShellExtensions/OCOverlays/OCOverlay.cpp index 8e0265d67..967231d86 100644 --- a/shell_integration/windows/OCShellExtensions/OCOverlays/OCOverlay.cpp +++ b/shell_integration/windows/OCShellExtensions/OCOverlays/OCOverlay.cpp @@ -36,17 +36,13 @@ extern HINSTANCE instanceHandle; #define IDM_DISPLAY 0 #define IDB_OK 101 -namespace { - static std::vector s_watchedDirectories; -} - OCOverlay::OCOverlay(int state) - : _communicationSocket(0) - , _referenceCount(1) - , _checker(new RemotePathChecker(PORT)) + : _referenceCount(1) , _state(state) { + static RemotePathChecker s_remotePathChecker; + _checker = &s_remotePathChecker; } OCOverlay::~OCOverlay(void) @@ -121,23 +117,13 @@ IFACEMETHODIMP OCOverlay::GetPriority(int *pPriority) IFACEMETHODIMP OCOverlay::IsMemberOf(PCWSTR pwszPath, DWORD dwAttrib) { - - //if(!_IsOverlaysEnabled()) - //{ - // return MAKE_HRESULT(S_FALSE, 0, 0); - //} - - // FIXME: Use Registry instead, this will only trigger once - // and now follow any user changes in the client - if (s_watchedDirectories.empty()) { - s_watchedDirectories = _checker->WatchedDirectories(); - } + auto watchedDirectories = _checker->WatchedDirectories(); wstring wpath(pwszPath); - wpath.append(L"\\"); + //wpath.append(L"\\"); vector::iterator it; bool watched = false; - for (it = s_watchedDirectories.begin(); it != s_watchedDirectories.end(); ++it) { + for (it = watchedDirectories.begin(); it != watchedDirectories.end(); ++it) { if (StringUtil::begins_with(wpath, *it)) { watched = true; } diff --git a/shell_integration/windows/OCShellExtensions/OCOverlays/OCOverlay.h b/shell_integration/windows/OCShellExtensions/OCOverlays/OCOverlay.h index a84bc45b0..2e82114cd 100644 --- a/shell_integration/windows/OCShellExtensions/OCOverlays/OCOverlay.h +++ b/shell_integration/windows/OCShellExtensions/OCOverlays/OCOverlay.h @@ -35,14 +35,13 @@ public: IFACEMETHODIMP_(ULONG) Release(); protected: - ~OCOverlay(void); + ~OCOverlay(); private: //bool _GenerateMessage(const wchar_t*, std::wstring*); bool _IsOverlaysEnabled(); long _referenceCount; - CommunicationSocket* _communicationSocket; RemotePathChecker* _checker; int _state; }; diff --git a/shell_integration/windows/OCShellExtensions/OCUtil/CommunicationSocket.cpp b/shell_integration/windows/OCShellExtensions/OCUtil/CommunicationSocket.cpp index a2f4a9413..13494d928 100644 --- a/shell_integration/windows/OCShellExtensions/OCUtil/CommunicationSocket.cpp +++ b/shell_integration/windows/OCShellExtensions/OCUtil/CommunicationSocket.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include @@ -30,8 +31,8 @@ using namespace std; #define DEFAULT_BUFLEN 4096 -CommunicationSocket::CommunicationSocket(int port) - : _port(port), _clientSocket(INVALID_SOCKET) +CommunicationSocket::CommunicationSocket() + : _pipe(INVALID_HANDLE_VALUE) { } @@ -43,64 +44,42 @@ CommunicationSocket::~CommunicationSocket() bool CommunicationSocket::Close() { WSACleanup(); - bool closed = (closesocket(_clientSocket) == 0); - shutdown(_clientSocket, SD_BOTH); - _clientSocket = INVALID_SOCKET; - return closed; + if (_pipe == INVALID_HANDLE_VALUE) { + return false; + } + CloseHandle(_pipe); + _pipe = INVALID_HANDLE_VALUE; + return true; } bool CommunicationSocket::Connect() { - WSADATA wsaData; + auto pipename = std::wstring(L"\\\\.\\pipe\\"); + pipename += L"ownCloud"; - HRESULT iResult = WSAStartup(MAKEWORD(2, 2), &wsaData); + _pipe = CreateFile(pipename.data(), GENERIC_READ | GENERIC_WRITE, 0, NULL, OPEN_EXISTING, 0, NULL); - if (iResult != NO_ERROR) { - int error = WSAGetLastError(); - } - - - _clientSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); - - if (_clientSocket == INVALID_SOCKET) { - //int error = WSAGetLastError(); - Close(); - return false; - } - - struct sockaddr_in clientService; - - clientService.sin_family = AF_INET; - clientService.sin_addr.s_addr = inet_addr(PLUG_IN_SOCKET_ADDRESS); - clientService.sin_port = htons(_port); - - iResult = connect(_clientSocket, (SOCKADDR*)&clientService, sizeof(clientService)); - DWORD timeout = 500; // ms - setsockopt(_clientSocket, SOL_SOCKET, SO_RCVTIMEO, (const char*) &timeout, sizeof(DWORD)); - - if (iResult == SOCKET_ERROR) { - //int error = WSAGetLastError(); - Close(); - return false; - } - return true; + if (_pipe == INVALID_HANDLE_VALUE) { + return false; + } + return true; } bool CommunicationSocket::SendMsg(const wchar_t* message) { - const char* utf8_msg = StringUtil::toUtf8(message); - size_t result = send(_clientSocket, utf8_msg, (int)strlen(utf8_msg), 0); - delete[] utf8_msg; + auto utf8_msg = StringUtil::toUtf8(message); - if (result == SOCKET_ERROR) { - //int error = WSAGetLastError(); - closesocket(_clientSocket); - return false; - } - - return true; + DWORD numBytesWritten = 0; + auto result = WriteFile( _pipe, utf8_msg.c_str(), DWORD(utf8_msg.size()), &numBytesWritten, NULL); + if (result) { + return true; + } else { +// qWarning() << "Failed to send data." <<; + // look up error code here using GetLastError() + return false; + } } bool CommunicationSocket::ReadLine(wstring* response) @@ -109,21 +88,36 @@ bool CommunicationSocket::ReadLine(wstring* response) return false; } - vector resp_utf8; - char buffer; + response->clear(); + + Sleep(50); + while (true) { - int bytesRead = recv(_clientSocket, &buffer, 1, 0); - if (bytesRead <= 0) { - response = 0; + int lbPos = 0; + auto it = std::find(_buffer.begin() + lbPos, _buffer.end(), '\n'); + if (it != _buffer.end()) { + *response = StringUtil::toUtf16(_buffer.data(), DWORD(it - _buffer.begin())); + _buffer.erase(_buffer.begin(), it + 1); + return true; + } + + std::array resp_utf8; + DWORD numBytesRead = 0; + DWORD totalBytesAvailable = 0; + PeekNamedPipe(_pipe, NULL, 0, 0, &totalBytesAvailable, 0); + if (totalBytesAvailable == 0) { return false; } - if (buffer == '\n') { - resp_utf8.push_back(0); - *response = StringUtil::toUtf16(&resp_utf8[0], resp_utf8.size()); - return true; - } else { - resp_utf8.push_back(buffer); - } + auto result = ReadFile(_pipe, resp_utf8.data(), DWORD(resp_utf8.size()), &numBytesRead, NULL); + if (!result) { +// qWarning() << "Failed to read data from the pipe"; + return false; + } + if (numBytesRead <= 0) { + return false; + } + _buffer.insert(_buffer.end(), resp_utf8.begin(), resp_utf8.begin()+numBytesRead); + continue; } } diff --git a/shell_integration/windows/OCShellExtensions/OCUtil/CommunicationSocket.h b/shell_integration/windows/OCShellExtensions/OCUtil/CommunicationSocket.h index 7bbf1bf63..7e9cd118f 100644 --- a/shell_integration/windows/OCShellExtensions/OCUtil/CommunicationSocket.h +++ b/shell_integration/windows/OCShellExtensions/OCUtil/CommunicationSocket.h @@ -20,12 +20,13 @@ #pragma warning (disable : 4251) #include +#include #include class __declspec(dllexport) CommunicationSocket { public: - CommunicationSocket(int port); + CommunicationSocket(); ~CommunicationSocket(); bool Connect(); @@ -34,9 +35,11 @@ public: bool SendMsg(const wchar_t*); bool ReadLine(std::wstring*); + HANDLE Event() { return _pipe; } + private: - int _port; - SOCKET _clientSocket; + HANDLE _pipe; + std::vector _buffer; }; #endif \ No newline at end of file diff --git a/shell_integration/windows/OCShellExtensions/OCUtil/RemotePathChecker.cpp b/shell_integration/windows/OCShellExtensions/OCUtil/RemotePathChecker.cpp index a6d9cce28..5967245fc 100644 --- a/shell_integration/windows/OCShellExtensions/OCUtil/RemotePathChecker.cpp +++ b/shell_integration/windows/OCShellExtensions/OCUtil/RemotePathChecker.cpp @@ -20,88 +20,120 @@ #include #include #include +#include +#include + +#include using namespace std; -RemotePathChecker::RemotePathChecker(int port) - : _port(port) + +// This code is run in a thread +void RemotePathChecker::workerThreadLoop() { + CommunicationSocket socket; + std::unordered_set asked; + if (!socket.Connect()) { + return; + //FIXME! what if this fails! what if we are disconnected later? + } + + while(!_stop) { + { + std::unique_lock lock(_mutex); + while (!_pending.empty() && !_stop) { + auto filePath = _pending.front(); + _pending.pop(); + + lock.unlock(); + if (!asked.count(filePath)) { + asked.insert(filePath); + socket.SendMsg(wstring(L"RETRIEVE_FILE_STATUS:" + filePath + L'\n').data()); + } + lock.lock(); + } + } + + std::wstring response; + while (!_stop && socket.ReadLine(&response)) { + if (StringUtil::begins_with(response, wstring(L"REGISTER_PATH:"))) { + wstring responsePath = response.substr(14); // length of REGISTER_PATH: + + std::unique_lock lock(_mutex); + _watchedDirectories.push_back(responsePath); + } else if (StringUtil::begins_with(response, wstring(L"STATUS:")) || + StringUtil::begins_with(response, wstring(L"BROADCAST:"))) { + + auto statusBegin = response.find(L':', 0); + assert(statusBegin != std::wstring::npos); + + auto statusEnd = response.find(L':', statusBegin + 1); + if (statusEnd == std::wstring::npos) { + // the command do not contains two colon? + continue; + } + + auto responseStatus = response.substr(statusBegin+1, statusEnd - statusBegin-1); + auto responsePath = response.substr(statusEnd+1); + auto state = _StrToFileState(responseStatus); + auto erased = asked.erase(responsePath); + + { std::unique_lock lock(_mutex); + _cache[responsePath] = state; + } + SHChangeNotify(SHCNE_MKDIR, SHCNF_PATH, responsePath.data(), NULL); + } + } + + if (_stop) + return; + } +} + + + +RemotePathChecker::RemotePathChecker() + : _thread([this]{ this->workerThreadLoop(); } ) + , _newQueries(CreateEvent(NULL, true, true, NULL)) +{ +} + +RemotePathChecker::~RemotePathChecker() +{ + _stop = true; + //_newQueries.notify_all(); + SetEvent(_newQueries); + _thread.join(); + CloseHandle(_newQueries); } vector RemotePathChecker::WatchedDirectories() { - vector watchedDirectories; - wstring response; - bool needed = false; - - CommunicationSocket socket(_port); - socket.Connect(); - - while (socket.ReadLine(&response)) { - if (StringUtil::begins_with(response, wstring(L"REGISTER_PATH:"))) { - size_t pathBegin = response.find(L':', 0); - if (pathBegin == -1) { - continue; - } - - // chop trailing '\n' - wstring responsePath = response.substr(pathBegin + 1, response.length()-1); - watchedDirectories.push_back(responsePath); - } - } - - return watchedDirectories; + std::unique_lock lock(_mutex); + return _watchedDirectories; } bool RemotePathChecker::IsMonitoredPath(const wchar_t* filePath, int* state) { - wstring request; - wstring response; - bool needed = false; + assert(state); assert(filePath); - CommunicationSocket socket(_port); - socket.Connect(); - request = L"RETRIEVE_FILE_STATUS:"; - request += filePath; - request += L'\n'; + std::unique_lock lock(_mutex); - if (!socket.SendMsg(request.c_str())) { - return false; - } + auto path = std::wstring(filePath); - while (socket.ReadLine(&response)) { - // discard broadcast messages - if (StringUtil::begins_with(response, wstring(L"STATUS:"))) { - break; - } - } + auto it = _cache.find(path); + if (it != _cache.end()) { + *state = it->second; + return true; + } - size_t statusBegin = response.find(L':', 0); - if (statusBegin == -1) - return false; + _pending.push(filePath); + SetEvent(_newQueries); + return false; - size_t statusEnd = response.find(L':', statusBegin + 1); - if (statusEnd == -1) - return false; - - - wstring responseStatus = response.substr(statusBegin+1, statusEnd - statusBegin-1); - wstring responsePath = response.substr(statusEnd+1); - if (responsePath == filePath) { - if (!state) { - return false; - } - *state = _StrToFileState(responseStatus); - if (*state == StateNone) { - return false; - } - needed = true; - } - - return needed; } -int RemotePathChecker::_StrToFileState(const std::wstring &str) +RemotePathChecker::FileState RemotePathChecker::_StrToFileState(const std::wstring &str) { if (str == L"NOP" || str == L"NONE") { return StateNone; diff --git a/shell_integration/windows/OCShellExtensions/OCUtil/RemotePathChecker.h b/shell_integration/windows/OCShellExtensions/OCUtil/RemotePathChecker.h index 1fb1ef9f3..5c4c43f98 100644 --- a/shell_integration/windows/OCShellExtensions/OCUtil/RemotePathChecker.h +++ b/shell_integration/windows/OCShellExtensions/OCUtil/RemotePathChecker.h @@ -16,6 +16,12 @@ #include #include +#include +#include +#include +#include +#include +#include #pragma once @@ -29,14 +35,32 @@ public: StateWarning, StateWarningSWM, StateNone }; - RemotePathChecker(int port); + RemotePathChecker(); + ~RemotePathChecker(); std::vector WatchedDirectories(); bool IsMonitoredPath(const wchar_t* filePath, int* state); private: - int _StrToFileState(const std::wstring &str); - int _port; + FileState _StrToFileState(const std::wstring &str); + std::mutex _mutex; + std::thread _thread; + std::atomic _stop; + // Everything here is protected by the _mutex + + /** The list of paths we need to query. The main thread fill this, and the worker thread + * send that to the socket. */ + std::queue _pending; + + std::unordered_map _cache; + std::vector _watchedDirectories; + + + // The main thread notifies when there are new items in _pending + //std::condition_variable _newQueries; + HANDLE _newQueries; + + void workerThreadLoop(); }; #endif \ No newline at end of file diff --git a/shell_integration/windows/OCShellExtensions/OCUtil/StringUtil.cpp b/shell_integration/windows/OCShellExtensions/OCUtil/StringUtil.cpp index 0c7af0b82..aa7ad5569 100644 --- a/shell_integration/windows/OCShellExtensions/OCUtil/StringUtil.cpp +++ b/shell_integration/windows/OCShellExtensions/OCUtil/StringUtil.cpp @@ -11,22 +11,26 @@ * details. */ -#include +#include +#include +#include #include "StringUtil.h" -char* StringUtil::toUtf8(const wchar_t *utf16, int len) +std::string StringUtil::toUtf8(const wchar_t *utf16, int len) { - int newlen = WideCharToMultiByte(CP_UTF8, 0, utf16, len, NULL, 0, NULL, NULL); - char* str = new char[newlen]; - WideCharToMultiByte(CP_UTF8, 0, utf16, -1, str, newlen, NULL, NULL); - return str; + if (len < 0) { + len = wcslen(utf16); + } + std::wstring_convert > converter; + return converter.to_bytes(utf16, utf16+len); } -wchar_t* StringUtil::toUtf16(const char *utf8, int len) +std::wstring StringUtil::toUtf16(const char *utf8, int len) { - int newlen = MultiByteToWideChar(CP_UTF8, 0, utf8, len, NULL, 0); - wchar_t* wstr = new wchar_t[newlen]; - MultiByteToWideChar(CP_UTF8, 0, utf8, -1, wstr, newlen); - return wstr; + if (len < 0) { + len = strlen(utf8); + } + std::wstring_convert > converter; + return converter.from_bytes(utf8, utf8+len); } diff --git a/shell_integration/windows/OCShellExtensions/OCUtil/StringUtil.h b/shell_integration/windows/OCShellExtensions/OCUtil/StringUtil.h index cf6ec8606..d64eda7e5 100644 --- a/shell_integration/windows/OCShellExtensions/OCUtil/StringUtil.h +++ b/shell_integration/windows/OCShellExtensions/OCUtil/StringUtil.h @@ -20,15 +20,14 @@ class __declspec(dllexport) StringUtil { public: - static char* toUtf8(const wchar_t* utf16, int len = -1); - static wchar_t* toUtf16(const char* utf8, int len = -1); - + static std::string toUtf8(const wchar_t* utf16, int len = -1); + static std::wstring toUtf16(const char* utf8, int len = -1); template static bool begins_with(const T& input, const T& match) { return input.size() >= match.size() - && equal(match.begin(), match.end(), input.begin()); + && std::equal(match.begin(), match.end(), input.begin()); } };