#include "tls.h"

#if defined(_WIN32)
#define TF_TLS_SCHANNEL
#elif defined(__MACH__)
#define TF_TLS_APPLE
#else
#define TF_TLS_OPENSSL
#endif

/*
** NOTE(review): the include targets in the reviewed text were missing.  The
** lists below cover every name this file uses; confirm them against the build.
*/
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>

#if defined(TF_TLS_OPENSSL)
#include <ctype.h>
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/pem.h>
#include <openssl/ssl.h>
#include <openssl/x509.h>
#include <openssl/x509v3.h>
#elif defined(TF_TLS_APPLE)
#elif defined(TF_TLS_SCHANNEL)
#endif

/* Which side of the handshake a session plays. */
typedef enum _direction_t
{
	k_direction_undetermined,
	k_direction_accept,
	k_direction_connect,
} direction_t;

/* Shared TLS configuration (certificate, key, trust store). */
typedef struct _tf_tls_context_t
{
#if defined(TF_TLS_OPENSSL)
	SSL_CTX* context;
#elif defined(TF_TLS_APPLE)
#elif defined(TF_TLS_SCHANNEL)
#endif
} tf_tls_context_t;

/*
** One TLS connection.  Encrypted bytes flow through the two memory BIOs:
** bio_in receives bytes from the peer, bio_out holds bytes for the peer.
*/
typedef struct _tf_tls_session_t
{
#if defined(TF_TLS_OPENSSL)
	tf_tls_context_t* context;
	BIO* bio_in;
	BIO* bio_out;
	SSL* ssl;
	const char* hostname;
	direction_t direction;
#elif defined(TF_TLS_APPLE)
#elif defined(TF_TLS_SCHANNEL)
#endif
} tf_tls_session_t;

/*
** Legacy C++ back-ends (SecureTransport and SChannel).  They implement the
** TlsContext / TlsSession classes, which are not declared in this file.
*/
#if !defined (_WIN32) && !defined (__MACH__)
#elif defined (__MACH__)

#include <CoreFoundation/CoreFoundation.h>
#include <Security/Security.h>
#include <Security/SecureTransport.h>
#include <algorithm>
#include <cstring>
#include <string>
#include <vector>

extern "C" SecIdentityRef SecIdentityCreate(CFAllocatorRef allocator, SecCertificateRef certificate, SecKeyRef privateKey);

// TLS configuration backed by SecureTransport.
class TlsContext_osx : public TlsContext {
public:
	~TlsContext_osx() override;
	TlsSession* createSession() override;
	bool setCertificate(const char* certificate) override;
	bool setPrivateKey(const char* privateKey) override;
	bool addTrustedCertificate(const char* certificate) override;

	SecKeyRef& getPrivateKey() { return _privateKey; }
	SecCertificateRef& getCertificate() { return _certificate; }
	CFArrayRef getTrustedCertificates() { return _trustedCertificates; }

private:
	SecKeyRef _privateKey = 0;
	SecCertificateRef _certificate = 0;
	CFMutableArrayRef _trustedCertificates = 0;
};

// One SecureTransport connection; encrypted bytes are staged in _inBuffer
// (from the peer) and _outBuffer (for the peer).
class TlsSession_osx : public TlsSession {
public:
	TlsSession_osx(TlsContext_osx* context);
	~TlsSession_osx();

	void startConnect() override;
	void startAccept() override;
	void shutdown() override;
	HandshakeResult handshake() override;

	int readPlain(char* buffer, size_t bytes) override;
	int writePlain(const char* buffer, size_t bytes) override;
	int readEncrypted(char* buffer, size_t bytes) override;
	int writeEncrypted(const char* buffer, size_t bytes) override;

	void setHostname(const char* hostname) override { _hostname = hostname; }
	virtual int getPeerCertificate(char* buffer, size_t bytes) override;

private:
	static OSStatus writeCallback(SSLConnectionRef connection, const void* data, size_t* dataLength);
	static OSStatus readCallback(SSLConnectionRef connection, void* data, size_t* dataLength);

	TlsContext_osx* _context = 0;
	SSLContextRef _session = 0;
	std::vector<char> _inBuffer;
	std::vector<char> _outBuffer;
	std::string _hostname;
	bool _shutdown = false;
};

// Imports the first item of the requested type from a PEM string.
// Returns a retained reference the caller must CFRelease, or 0 on failure.
static CFTypeRef importFirstPemItem(const char* pem, SecExternalItemType itemType) {
	if (!pem) {
		return 0;
	}
	CFTypeRef result = 0;
	// kCFAllocatorNull: the CFData borrows the caller's string and must not
	// attempt to free it when released.
	CFDataRef data = CFDataCreateWithBytesNoCopy(kCFAllocatorDefault, reinterpret_cast<const UInt8*>(pem), std::strlen(pem), kCFAllocatorNull);
	if (data) {
		CFArrayRef items = 0;
		SecExternalFormat format = kSecFormatPEMSequence;
		OSStatus status = SecItemImport(data, 0, &format, &itemType, 0, 0, 0, &items);
		if (status == noErr && items && CFArrayGetCount(items) > 0) {
			// The array owns its elements; retain the one that outlives it.
			result = CFRetain(CFArrayGetValueAtIndex(items, 0));
		}
		if (items) {
			CFRelease(items);
		}
		CFRelease(data);
	}
	return result;
}

TlsContext_osx::~TlsContext_osx() {
	if (_privateKey) {
		CFRelease(_privateKey);
		_privateKey = 0;
	}
	if (_certificate) {
		CFRelease(_certificate);
		_certificate = 0;
	}
	if (_trustedCertificates) {
		CFRelease(_trustedCertificates);
		_trustedCertificates = 0;
	}
}

TlsSession* TlsContext_osx::createSession() {
	return new TlsSession_osx(this);
}

// Replaces the context certificate with the first certificate in the PEM text.
bool TlsContext_osx::setCertificate(const char* certificate) {
	if (_certificate) {
		CFRelease(_certificate);
		_certificate = 0;
	}
	_certificate = (SecCertificateRef)importFirstPemItem(certificate, kSecItemTypeCertificate);
	return _certificate != 0;
}

// Replaces the context private key with the first key in the PEM text.
bool TlsContext_osx::setPrivateKey(const char* privateKey) {
	if (_privateKey) {
		CFRelease(_privateKey);
		_privateKey = 0;
	}
	_privateKey = (SecKeyRef)importFirstPemItem(privateKey, kSecItemTypePrivateKey);
	return _privateKey != 0;
}

// Appends the first certificate in the PEM text to the trusted list.
bool TlsContext_osx::addTrustedCertificate(const char* certificate) {
	if (!_trustedCertificates) {
		_trustedCertificates = CFArrayCreateMutable(kCFAllocatorDefault, 0, &kCFTypeArrayCallBacks);
	}
	SecCertificateRef certificateItem = (SecCertificateRef)importFirstPemItem(certificate, kSecItemTypeCertificate);
	bool added = false;
	if (certificateItem) {
		if (_trustedCertificates) {
			// The array retains the value (kCFTypeArrayCallBacks).
			CFArrayAppendValue(_trustedCertificates, certificateItem);
			added = true;
		}
		CFRelease(certificateItem);
	}
	return added;
}

TlsSession_osx::TlsSession_osx(TlsContext_osx* context) {
	_context = context;
}

TlsSession_osx::~TlsSession_osx() {
	if (_session) {
		CFRelease(_session);
		_session = 0;
	}
}

// Begins a server-side handshake using the context certificate and key.
void TlsSession_osx::startAccept() {
	_session = SSLCreateContext(0, kSSLServerSide, kSSLStreamType);
	if (_context->getCertificate() && _context->getPrivateKey()) {
		SecIdentityRef identity = SecIdentityCreate(kCFAllocatorDefault, _context->getCertificate(), _context->getPrivateKey());
		if (identity) {
			CFArrayRef array = CFArrayCreate(kCFAllocatorDefault, (const void**)&identity, 1, &kCFTypeArrayCallBacks);
			if (array) {
				SSLSetCertificate(_session, array);
				CFRelease(array);
			}
			CFRelease(identity);
		}
	}
	SSLSetIOFuncs(_session, readCallback, writeCallback);
	SSLSetConnection(_session, this);
	handshake();
}

// Begins a client-side handshake against _hostname.
void TlsSession_osx::startConnect() {
	_session = SSLCreateContext(0, kSSLClientSide, kSSLStreamType);
	if (_context->getTrustedCertificates()) {
		// XXX: SSLSetTrustedRoots(_session, _context->getTrustedCertificates(), false);
	}
	SSLSetIOFuncs(_session, readCallback, writeCallback);
	SSLSetConnection(_session, this);
	SSLSetPeerDomainName(_session, _hostname.c_str(), _hostname.size());
	handshake();
}

// Closes immediately when nothing is pending, otherwise records the request.
void TlsSession_osx::shutdown() {
	if (!_outBuffer.size()) {
		SSLClose(_session);
		_shutdown = false;
	} else {
		_shutdown = true;
	}
}

TlsSession::HandshakeResult TlsSession_osx::handshake() {
	TlsSession::HandshakeResult result = TlsSession::kFailed;
	OSStatus status = SSLHandshake(_session);
	switch (status) {
	case noErr:
		result = TlsSession::kDone;
		break;
	case errSSLWouldBlock:
		result = TlsSession::kMore;
		break;
	default:
		result = TlsSession::kFailed;
		break;
	}
	return result;
}

// Returns decrypted bytes read, kReadZero on graceful close, kReadFailed on error.
int TlsSession_osx::readPlain(char* buffer, size_t bytes) {
	int result = 0;
	size_t processed = bytes;
	OSStatus status = SSLRead(_session, buffer, bytes, &processed);
	if (status == noErr) {
		result = processed;
	} else if (status == errSSLWouldBlock) {
		result = processed;
	} else if (status == errSSLClosedGraceful) {
		result = kReadZero;
	} else {
		result = kReadFailed;
	}
	return result;
}

// Returns the number of plaintext bytes accepted, or -1 on any non-success status.
int TlsSession_osx::writePlain(const char* buffer, size_t bytes) {
	int result = 0;
	size_t processed = 0;
	OSStatus status = SSLWrite(_session, buffer, bytes, &processed);
	if (status == noErr) {
		result = processed;
	} else {
		result = -1;
	}
	return result;
}

// SecureTransport output hook: queues encrypted bytes in _outBuffer.
OSStatus TlsSession_osx::writeCallback(SSLConnectionRef connection, const void* data, size_t* dataLength) {
	TlsSession_osx* tls = reinterpret_cast<TlsSession_osx*>(const_cast<void*>(connection));
	tls->_outBuffer.insert(tls->_outBuffer.end(), reinterpret_cast<const char*>(data), reinterpret_cast<const char*>(data) + *dataLength);
	if (tls->_shutdown && !tls->_outBuffer.size()) {
		SSLClose(tls->_session);
		tls->_shutdown = false;
	}
	return noErr;
}

// SecureTransport input hook: hands over buffered encrypted bytes and
// reports errSSLWouldBlock when fewer than requested were available.
OSStatus TlsSession_osx::readCallback(SSLConnectionRef connection, void* data, size_t* dataLength) {
	TlsSession_osx* tls = reinterpret_cast<TlsSession_osx*>(const_cast<void*>(connection));
	OSStatus result = noErr;
	size_t bytes = std::min(tls->_inBuffer.size(), *dataLength);
	if (bytes > 0) {
		std::memcpy(data, tls->_inBuffer.data(), bytes);
		tls->_inBuffer.erase(tls->_inBuffer.begin(), tls->_inBuffer.begin() + bytes);
	}
	if (bytes < *dataLength) {
		result = errSSLWouldBlock;
	}
	*dataLength = bytes;
	return result;
}

// Drains up to `bytes` of encrypted output destined for the peer.
int TlsSession_osx::readEncrypted(char* buffer, size_t bytes) {
	size_t size = std::min(bytes, _outBuffer.size());
	if (size > 0) {
		std::memcpy(buffer, _outBuffer.data(), size);
		_outBuffer.erase(_outBuffer.begin(), _outBuffer.begin() + size);
	}
	return size;
}

// Queues encrypted bytes received from the peer.
int TlsSession_osx::writeEncrypted(const char* buffer, size_t bytes) {
	_inBuffer.insert(_inBuffer.end(), buffer, buffer + bytes);
	return bytes;
}

// Writes the peer's leaf certificate as PEM into buffer.
// Returns the byte count, or -1 if unavailable or the buffer is too small.
int TlsSession_osx::getPeerCertificate(char* buffer, size_t size) {
	int result = -1;
	SecTrustRef trust = 0;
	if (SSLCopyPeerTrust(_session, &trust) == noErr && trust) {
		if (SecTrustGetCertificateCount(trust) > 0) {
			SecCertificateRef certificate = SecTrustGetCertificateAtIndex(trust, 0);
			CFDataRef data = 0;
			if (SecItemExport(certificate, kSecFormatX509Cert, kSecItemPemArmour, 0, &data) == noErr && data) {
				size_t actualSize = CFDataGetLength(data);
				if (actualSize <= size) {
					CFDataGetBytes(data, CFRangeMake(0, actualSize), reinterpret_cast<UInt8*>(buffer));
					result = actualSize;
				}
				CFRelease(data);
			}
		}
		CFRelease(trust);
	}
	return result;
}

TlsContext* TlsContext::create() {
	return new TlsContext_osx();
}

#elif defined (_WIN32)

#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <string>
#include <vector>

#define SECURITY_WIN32
#define NOMINMAX
#include <windows.h>
#include <schannel.h>
#include <security.h>
#undef SECURITY_WIN32
#undef NOMINMAX

PSecurityFunctionTable getSecurityLibrary();

// TLS configuration backed by SChannel / CryptoAPI.
class TlsContext_sspi : public TlsContext {
public:
	TlsContext_sspi();
	~TlsContext_sspi();

	TlsSession* createSession() override;
	bool setCertificate(const char* certificate) override;
	bool setPrivateKey(const char* privateKey) override;
	bool addTrustedCertificate(const char* certificate) override { return false; }

	PCCERT_CONTEXT& getCertificate() { return _certificate; }

private:
	static const char* kContainerName;

	void importKeyAndCertificate();

	HCRYPTPROV _provider = 0;
	HCERTSTORE _store = INVALID_HANDLE_VALUE;
	PCCERT_CONTEXT _certificate = 0;
	bool _dirty = true;
};

const char* TlsContext_sspi::kContainerName = "_tmp0";

// One SChannel connection.
class TlsSession_sspi : public TlsSession {
public:
	TlsSession_sspi(TlsContext_sspi* context);
	~TlsSession_sspi();

	void startConnect() override;
	void startAccept() override;
	int getPeerCertificate(char* buffer, size_t size) override;
	void shutdown() override;
	HandshakeResult handshake() override;

	int readPlain(char* buffer, size_t bytes) override;
	int writePlain(const char* buffer, size_t bytes) override;
	int readEncrypted(char* buffer, size_t bytes) override;
	int writeEncrypted(const char* buffer, size_t bytes) override;

	void setHostname(const char* hostname) override { _hostname = hostname; }

	static void setError(HRESULT error);
	bool getError(char* buffer, size_t bytes) override;

private:
	HandshakeResult handshakeInternal(bool initial);

	TlsContext_sspi* _context;
	CredHandle _credentialsHandle;
	CtxtHandle _securityContext;
	SecPkgContext_StreamSizes _sizes;
	enum { k_direction_undetermined, k_direction_connect, k_direction_accept } _direction = k_direction_undetermined;
	bool _initial = false;

	static HRESULT _lastError;

	std::vector<char> _inBuffer;
	std::vector<char> _outBuffer;
	std::vector<char> _decryptedBuffer;
	std::string _hostname;
};

HRESULT TlsSession_sspi::_lastError = S_OK;

// Loads security.dll on first use and returns its SSPI function table.
PSecurityFunctionTable getSecurityLibrary() {
	static PSecurityFunctionTable security;
	static bool loaded;
	if (!loaded) {
		HMODULE module = LoadLibrary("security.dll");
		assert(module && "failed to load security.dll");
		PSecurityFunctionTable (*table)() = (PSecurityFunctionTable(*)())GetProcAddress(module, "InitSecurityInterfaceA");
		assert(table && "failed to load security.dll");
		security = table();
		assert(security && "No function table in security.dll");
		loaded = true;
	}
	return security;
}

// Acquires (creating if needed) the key container and opens the "MY" store.
TlsContext_sspi::TlsContext_sspi() {
	if (!CryptAcquireContext(&_provider, kContainerName, MS_DEF_RSA_SCHANNEL_PROV, PROV_RSA_SCHANNEL, CRYPT_NEWKEYSET)) {
		if (GetLastError() != NTE_EXISTS || !CryptAcquireContext(&_provider, kContainerName, MS_DEF_RSA_SCHANNEL_PROV, PROV_RSA_SCHANNEL, 0)) {
			TlsSession_sspi::setError(GetLastError());
		}
	}

	_store = CertOpenStore(CERT_STORE_PROV_SYSTEM, X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, _provider, CERT_SYSTEM_STORE_LOCAL_MACHINE | CERT_STORE_NO_CRYPT_RELEASE_FLAG | CERT_STORE_OPEN_EXISTING_FLAG, L"MY");
}

TlsContext_sspi::~TlsContext_sspi() {
	if (_store) {
		CertCloseStore(_store, 0);
		_store = INVALID_HANDLE_VALUE;
	}
	if (_provider) {
		CryptReleaseContext(_provider, 0);
		_provider = 0;
	}
	if (_certificate) {
		CertFreeCertificateContext(_certificate);
		_certificate = 0;
	}
}

// Decodes a PEM certificate, stores it on the context and adds it to the store.
bool TlsContext_sspi::setCertificate(const char* certificate) {
	if (_certificate) {
		CertFreeCertificateContext(_certificate);
		_certificate = 0;
	}

	std::vector<BYTE> certificateBuffer;
	DWORD size = 0;
	if (CryptStringToBinary(certificate, 0, CRYPT_STRING_BASE64HEADER, 0, &size, 0, 0)) {
		certificateBuffer.resize(size);
		if (!CryptStringToBinary(certificate, 0, CRYPT_STRING_BASE64HEADER, certificateBuffer.data(), &size, 0, 0)) {
			certificateBuffer.resize(0);
		}
	}

	if (certificateBuffer.size()) {
		_certificate = CertCreateCertificateContext(X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, certificateBuffer.data(), certificateBuffer.size());
	}

	// Only a successfully parsed certificate can be added to the store.
	if (_certificate) {
		if (!CertAddCertificateContextToStore(_store, _certificate, CERT_STORE_ADD_REPLACE_EXISTING, 0)) {
			TlsSession_sspi::setError(GetLastError());
		}
	} else {
		TlsSession_sspi::setError(GetLastError());
	}

	_dirty = true;
	return _certificate != 0;
}

// Decodes a PEM RSA private key and imports it into the key container.
bool TlsContext_sspi::setPrivateKey(const char* privateKey) {
	bool result = false;
	std::vector<BYTE> keyBuffer;
	std::vector<BYTE> keyBlob;

	DWORD size = 0;
	if (CryptStringToBinary(privateKey, 0, CRYPT_STRING_BASE64HEADER, 0, &size, 0, 0)) {
		keyBuffer.resize(size);
		if (!CryptStringToBinary(privateKey, 0, CRYPT_STRING_BASE64HEADER, keyBuffer.data(), &size, 0, 0)) {
			TlsSession_sspi::setError(GetLastError());
			keyBuffer.resize(0);
		}
	}

	size = 0;
	if (CryptDecodeObjectEx(X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, PKCS_RSA_PRIVATE_KEY, keyBuffer.data(), keyBuffer.size(), 0, 0, 0, &size)) {
		keyBlob.resize(size);
		if (!CryptDecodeObjectEx(X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, PKCS_RSA_PRIVATE_KEY, keyBuffer.data(), keyBuffer.size(), 0, 0, keyBlob.data(), &size)) {
			TlsSession_sspi::setError(GetLastError());
			keyBlob.resize(0);
		}
	}

	HCRYPTKEY cryptKey = 0;
	if (CryptImportKey(_provider, keyBlob.data(), keyBlob.size(), 0, 0, &cryptKey)) {
		CryptDestroyKey(cryptKey);
		cryptKey = 0;
		result = true;
	} else {
		TlsSession_sspi::setError(GetLastError());
	}

	_dirty = true;
	return result;
}

// Associates the certificate with the key container holding its private key.
void TlsContext_sspi::importKeyAndCertificate() {
	if (_certificate) {
		WCHAR wname[32];
		const size_t capacity = sizeof(wname) / sizeof(wname[0]);
		// Size by the destination array (not by sizeof of the source pointer)
		// and always terminate, since mbstowcs does not when it fills the count.
		size_t converted = mbstowcs(wname, kContainerName, capacity - 1);
		if (converted == (size_t)-1) {
			converted = 0;
		}
		wname[converted < capacity ? converted : capacity - 1] = 0;

		CRYPT_KEY_PROV_INFO info;
		ZeroMemory(&info, sizeof(info));
		info.pwszContainerName = wname;
		info.pwszProvName = MS_DEF_RSA_SCHANNEL_PROV_W;
		info.dwProvType = PROV_RSA_SCHANNEL;
		info.dwKeySpec = AT_KEYEXCHANGE;

		if (!CertSetCertificateContextProperty(_certificate, CERT_KEY_PROV_INFO_PROP_ID, 0, reinterpret_cast<const void*>(&info))) {
			TlsSession_sspi::setError(GetLastError());
		}
	}
	_dirty = false;
}

TlsSession* TlsContext_sspi::createSession() {
	if (_dirty) {
		importKeyAndCertificate();
	}
	return new TlsSession_sspi(this);
}

TlsSession_sspi::TlsSession_sspi(TlsContext_sspi* context) {
	_context = context;
	ZeroMemory(&_credentialsHandle, sizeof(_credentialsHandle));
	ZeroMemory(&_securityContext, sizeof(_securityContext));
	ZeroMemory(&_sizes, sizeof(_sizes));
}

TlsSession_sspi::~TlsSession_sspi() {
	getSecurityLibrary()->FreeCredentialsHandle(&_credentialsHandle);
	getSecurityLibrary()->DeleteSecurityContext(&_securityContext);
}

// Acquires inbound credentials and runs the first server handshake step.
void TlsSession_sspi::startAccept() {
	_direction = k_direction_accept;
	_initial = true;
	SCHANNEL_CRED credentials;
	ZeroMemory(&credentials, sizeof(credentials));
	credentials.dwVersion = SCHANNEL_CRED_VERSION;
	credentials.cCreds = 1;
	credentials.paCred = &_context->getCertificate();
	_lastError = getSecurityLibrary()->AcquireCredentialsHandleA(0, UNISP_NAME_A, SECPKG_CRED_INBOUND, 0, &credentials, 0, 0, &_credentialsHandle, 0);
	if (_lastError == S_OK) {
		handshakeInternal(true);
	}
}

// Acquires outbound credentials and runs the first client handshake step.
void TlsSession_sspi::startConnect() {
	_direction = k_direction_connect;
	_initial = true;
	SCHANNEL_CRED credentials;
	ZeroMemory(&credentials, sizeof(credentials));
	credentials.dwVersion = SCHANNEL_CRED_VERSION;
	credentials.dwFlags = SCH_CRED_NO_DEFAULT_CREDS;
	_lastError = getSecurityLibrary()->AcquireCredentialsHandleA(0, UNISP_NAME_A, SECPKG_CRED_OUTBOUND, 0, &credentials, 0, 0, &_credentialsHandle, 0);
	if (_lastError == S_OK) {
		handshakeInternal(true);
	}
}

// Writes the peer certificate as base64-with-header text into buffer.
// Returns the length reported by CryptBinaryToString, or -1 on failure.
int TlsSession_sspi::getPeerCertificate(char* buffer, size_t size) {
	int result = -1;
	PCCERT_CONTEXT certificate = 0;
	HRESULT status = getSecurityLibrary()->QueryContextAttributesA(&_securityContext, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &certificate);
	if (FAILED(status)) {
		_lastError = status;
	} else if (certificate) {
		DWORD bufferSize = static_cast<DWORD>(size);
		if (CryptBinaryToString(certificate->pbCertEncoded, certificate->cbCertEncoded, CRYPT_STRING_BASE64HEADER, buffer, &bufferSize)) {
			result = bufferSize;
		}
		// The queried certificate context is owned by the caller.
		CertFreeCertificateContext(certificate);
	}
	return result;
}

// Applies the SCHANNEL_SHUTDOWN control token and queues the resulting
// close message in _outBuffer.
void TlsSession_sspi::shutdown() {
	DWORD type = SCHANNEL_SHUTDOWN;
	SecBufferDesc bufferDesc;
	SecBuffer buffers[1];
	buffers[0].pvBuffer = &type;
	buffers[0].BufferType = SECBUFFER_TOKEN;
	buffers[0].cbBuffer = sizeof(type);
	bufferDesc.cBuffers = 1;
	bufferDesc.pBuffers = buffers;
	// ulVersion is the descriptor version, not a buffer type.
	bufferDesc.ulVersion = SECBUFFER_VERSION;
	SECURITY_STATUS status = getSecurityLibrary()->ApplyControlToken(&_securityContext, &bufferDesc);
	if (!FAILED(status)) {
		buffers[0].pvBuffer = 0;
		buffers[0].BufferType = SECBUFFER_TOKEN;
		buffers[0].cbBuffer = 0;
		bufferDesc.cBuffers = 1;
		bufferDesc.pBuffers = buffers;
		bufferDesc.ulVersion = SECBUFFER_VERSION;
		DWORD outFlags = 0;
		status = getSecurityLibrary()->InitializeSecurityContextA(
			&_credentialsHandle,
			&_securityContext,
			0,
			ISC_REQ_SEQUENCE_DETECT | ISC_REQ_REPLAY_DETECT | ISC_REQ_CONFIDENTIALITY | ISC_REQ_EXTENDED_ERROR | ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_STREAM,
			0,
			0,
			0,
			0,
			&_securityContext,
			&bufferDesc,
			&outFlags,
			0);
		if (!FAILED(status) && buffers[0].pvBuffer && buffers[0].cbBuffer) {
			const char* data = reinterpret_cast<const char*>(buffers[0].pvBuffer);
			_outBuffer.insert(_outBuffer.end(), data, data + buffers[0].cbBuffer);
		}
		// ISC_REQ_ALLOCATE_MEMORY output belongs to the caller.
		if (buffers[0].pvBuffer) {
			getSecurityLibrary()->FreeContextBuffer(buffers[0].pvBuffer);
		}
	}
}

TlsSession::HandshakeResult TlsSession_sspi::handshake() {
	return handshakeInternal(_initial);
}

// Runs one handshake step over the buffered input.  Output tokens are queued
// in _outBuffer; consumed input is removed from _inBuffer.
TlsSession::HandshakeResult TlsSession_sspi::handshakeInternal(bool initial) {
	SecBufferDesc outBuffer;
	SecBuffer outBuffers[1];
	SecBufferDesc inBuffer;
	SecBuffer inBuffers[2];
	DWORD outFlags = 0;
	outBuffers[0].pvBuffer = 0;
	outBuffers[0].BufferType = SECBUFFER_TOKEN;
	outBuffers[0].cbBuffer = 0;
	outBuffer.cBuffers = 1;
	outBuffer.pBuffers = outBuffers;
	outBuffer.ulVersion = SECBUFFER_VERSION;
	std::vector<char> buffer(_inBuffer);
	inBuffers[0].pvBuffer = buffer.data();
	inBuffers[0].cbBuffer = buffer.size();
	inBuffers[0].BufferType = SECBUFFER_TOKEN;
	inBuffers[1].pvBuffer = 0;
	inBuffers[1].cbBuffer = 0;
	inBuffers[1].BufferType = SECBUFFER_EMPTY;
	inBuffer.cBuffers = 2;
	inBuffer.pBuffers = inBuffers;
	inBuffer.ulVersion = SECBUFFER_VERSION;

	SECURITY_STATUS status = SEC_E_OK;

	if (_direction == k_direction_connect) {
		status = getSecurityLibrary()->InitializeSecurityContextA(
			&_credentialsHandle,
			initial ? 0 : &_securityContext,
			_hostname.size() ? const_cast<char*>(_hostname.c_str()) : 0,
			ISC_REQ_SEQUENCE_DETECT | ISC_REQ_REPLAY_DETECT | ISC_REQ_CONFIDENTIALITY | ISC_REQ_EXTENDED_ERROR | ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_STREAM,
			0,
			0,
			&inBuffer,
			0,
			&_securityContext,
			&outBuffer,
			&outFlags,
			0);
	} else if (_direction == k_direction_accept) {
		status = getSecurityLibrary()->AcceptSecurityContext(
			&_credentialsHandle,
			initial ? 0 : &_securityContext,
			&inBuffer,
			ASC_REQ_SEQUENCE_DETECT | ASC_REQ_REPLAY_DETECT | ASC_REQ_CONFIDENTIALITY | ASC_REQ_EXTENDED_ERROR | ASC_REQ_ALLOCATE_MEMORY | ASC_REQ_STREAM,
			0,
			&_securityContext,
			&outBuffer,
			&outFlags,
			0);
	}

	if (!FAILED(status)) {
		_initial = false;
	}

	TlsSession::HandshakeResult result = TlsSession::kFailed;

	// Bytes at the tail of the input that were not consumed by this step.
	size_t extra = 0;
	for (unsigned long i = 0; i < inBuffer.cBuffers; ++i) {
		if (inBuffers[i].BufferType == SECBUFFER_EXTRA && inBuffers[i].cbBuffer) {
			extra += inBuffers[i].cbBuffer;
		}
	}

	if (outBuffers[0].cbBuffer && outBuffers[0].pvBuffer) {
		const char* data = reinterpret_cast<const char*>(outBuffers[0].pvBuffer);
		_outBuffer.insert(_outBuffer.end(), data, data + outBuffers[0].cbBuffer);
		getSecurityLibrary()->FreeContextBuffer(outBuffers[0].pvBuffer);
	}

	if (status == SEC_E_OK) {
		result = TlsSession::kDone;
	} else if (status == SEC_E_INCOMPLETE_MESSAGE || status == SEC_I_CONTINUE_NEEDED) {
		result = TlsSession::kMore;
	} else if (FAILED(status)) {
		result = TlsSession::kFailed;
	}

	// An incomplete message consumed nothing: keep the partial record so the
	// next call sees it together with the bytes that arrive later.
	if (status != SEC_E_INCOMPLETE_MESSAGE) {
		extra = std::min(extra, _inBuffer.size());
		_inBuffer.erase(_inBuffer.begin(), _inBuffer.end() - extra);
	}

	if (result == TlsSession::kDone) {
		status = getSecurityLibrary()->QueryContextAttributesA(&_securityContext, SECPKG_ATTR_STREAM_SIZES, &_sizes);
		if (FAILED(status)) {
			result = TlsSession::kFailed;
		}
	}

	return result;
}

// Returns decrypted bytes, 0 when more input is needed, kReadZero when the
// peer closed the context, or kReadFailed on error.
int TlsSession_sspi::readPlain(char* buffer, size_t bytes) {
	int result = TlsSession::kReadFailed;
	if (bytes <= _decryptedBuffer.size()) {
		std::memcpy(buffer, _decryptedBuffer.data(), bytes);
		_decryptedBuffer.erase(_decryptedBuffer.begin(), _decryptedBuffer.begin() + bytes);
		result = bytes;
	} else if (_inBuffer.size()) {
		SecBufferDesc bufferDesc;
		SecBuffer buffers[4];
		ZeroMemory(buffers, sizeof(buffers));
		std::vector<char> data(_inBuffer);
		buffers[0].pvBuffer = data.data();
		buffers[0].cbBuffer = data.size();
		buffers[0].BufferType = SECBUFFER_DATA;
		buffers[1].BufferType = SECBUFFER_EMPTY;
		buffers[2].BufferType = SECBUFFER_EMPTY;
		buffers[3].BufferType = SECBUFFER_EMPTY;
		bufferDesc.ulVersion = SECBUFFER_VERSION;
		bufferDesc.cBuffers = 4;
		bufferDesc.pBuffers = buffers;
		SECURITY_STATUS status = getSecurityLibrary()->DecryptMessage(&_securityContext, &bufferDesc, 0, 0);

		if (status == SEC_I_CONTEXT_EXPIRED) {
			_inBuffer.clear();
			result = TlsSession::kReadZero;
		} else if (status == SEC_E_INCOMPLETE_MESSAGE) {
			result = 0;
		} else if (status == SEC_E_OK) {
			result = 0;
			size_t extra = 0;
			for (unsigned long i = 0; i < bufferDesc.cBuffers; ++i) {
				if (buffers[i].BufferType == SECBUFFER_DATA) {
					const char* decrypted = reinterpret_cast<const char*>(buffers[i].pvBuffer);
					_decryptedBuffer.insert(_decryptedBuffer.end(), decrypted, decrypted + buffers[i].cbBuffer);
				} else if (buffers[i].BufferType == SECBUFFER_EXTRA) {
					extra += buffers[i].cbBuffer;
				}
			}
			extra = std::min(extra, _inBuffer.size());
			_inBuffer.erase(_inBuffer.begin(), _inBuffer.end() - extra);

			size_t actual = std::min(_decryptedBuffer.size(), bytes);
			if (actual > 0) {
				std::memcpy(buffer, _decryptedBuffer.data(), actual);
				_decryptedBuffer.erase(_decryptedBuffer.begin(), _decryptedBuffer.begin() + actual);
				result = actual;
			}
		} else {
			_inBuffer.clear();
			result = TlsSession::kReadFailed;
		}
	} else {
		size_t actual = std::min(_decryptedBuffer.size(), bytes);
		if (actual > 0) {
			std::memcpy(buffer, _decryptedBuffer.data(), actual);
			_decryptedBuffer.erase(_decryptedBuffer.begin(), _decryptedBuffer.begin() + actual);
			result = actual;
		} else {
			result = 0;
		}
	}
	return result;
}

// Encrypts one message and queues it in _outBuffer.
// Returns the number of plaintext bytes accepted, or -1 if encryption failed.
int TlsSession_sspi::writePlain(const char* buffer, size_t bytes) {
	SecBufferDesc bufferDesc;
	SecBuffer buffers[4];
	ZeroMemory(buffers, sizeof(buffers));
	std::vector<char> data(_sizes.cbHeader + _sizes.cbTrailer + bytes);
	std::memcpy(data.data() + _sizes.cbHeader, buffer, bytes);

	buffers[0].pvBuffer = data.data();
	buffers[0].cbBuffer = _sizes.cbHeader;
	buffers[0].BufferType = SECBUFFER_STREAM_HEADER;

	buffers[1].pvBuffer = data.data() + _sizes.cbHeader;
	buffers[1].cbBuffer = bytes;
	buffers[1].BufferType = SECBUFFER_DATA;

	buffers[2].pvBuffer = data.data() + _sizes.cbHeader + bytes;
	buffers[2].cbBuffer = _sizes.cbTrailer;
	buffers[2].BufferType = SECBUFFER_STREAM_TRAILER;

	buffers[3].BufferType = SECBUFFER_EMPTY;

	bufferDesc.ulVersion = SECBUFFER_VERSION;
	bufferDesc.cBuffers = 4;
	bufferDesc.pBuffers = buffers;
	SECURITY_STATUS status = getSecurityLibrary()->EncryptMessage(&_securityContext, 0, &bufferDesc, 0);
	if (FAILED(status)) {
		// Never queue the buffer contents when encryption did not succeed.
		_lastError = status;
		return -1;
	}
	for (unsigned long i = 0; i < bufferDesc.cBuffers; ++i) {
		if (buffers[i].BufferType != SECBUFFER_EMPTY && buffers[i].pvBuffer && buffers[i].cbBuffer) {
			const char* bufferData = reinterpret_cast<const char*>(buffers[i].pvBuffer);
			_outBuffer.insert(_outBuffer.end(), bufferData, bufferData + buffers[i].cbBuffer);
		}
	}
	return static_cast<int>(bytes);
}

// Drains up to `bytes` of encrypted output destined for the peer.
int TlsSession_sspi::readEncrypted(char* buffer, size_t bytes) {
	size_t size = std::min(bytes, _outBuffer.size());
	if (size > 0) {
		std::memcpy(buffer, _outBuffer.data(), size);
		_outBuffer.erase(_outBuffer.begin(), _outBuffer.begin() + size);
	}
	return size;
}

// Queues encrypted bytes received from the peer.
int TlsSession_sspi::writeEncrypted(const char* buffer, size_t bytes) {
	_inBuffer.insert(_inBuffer.end(), buffer, buffer + bytes);
	return bytes;
}

void TlsSession_sspi::setError(HRESULT error) {
	_lastError = error;
}

// Formats and clears the last recorded error; returns false if there is none.
bool TlsSession_sspi::getError(char* buffer, size_t bytes) {
	bool result = false;
	if (_lastError != S_OK) {
		DWORD length = FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, 0, _lastError, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)buffer, bytes, 0);
		_lastError = S_OK;
		if (length > 0) {
			result = true;
		}
	}
	return result;
}

TlsContext* TlsContext::create() {
	return new TlsContext_sspi();
}

#else

TlsContext* TlsContext::create() {
	return 0;
}

#endif

/*
** Creates a TLS context.  Returns NULL if allocation or back-end
** initialization fails.  Release with tf_tls_context_destroy().
*/
tf_tls_context_t* tf_tls_context_create()
{
	tf_tls_context_t* context = malloc(sizeof(tf_tls_context_t));
	if (!context)
	{
		return NULL;
	}
	memset(context, 0, sizeof(*context));
#if defined(TF_TLS_OPENSSL)
	SSL_library_init();
	SSL_load_error_strings();
	context->context = SSL_CTX_new(SSLv23_method());
	if (!context->context)
	{
		free(context);
		return NULL;
	}
	SSL_CTX_set_default_verify_paths(context->context);
#elif defined(TF_TLS_APPLE)
#elif defined(TF_TLS_SCHANNEL)
#endif
	return context;
}

/*
** Installs a PEM certificate.  The first certificate in the text becomes the
** context certificate; any that follow are added as extra chain certificates.
** Returns true if the first certificate was installed.
*/
bool tf_tls_context_set_certificate(tf_tls_context_t* context, const char* certificate)
{
	bool result = false;
#if defined(TF_TLS_OPENSSL)
	BIO* bio = BIO_new(BIO_s_mem());
	if (bio)
	{
		BIO_puts(bio, certificate);
		X509* x509 = PEM_read_bio_X509(bio, 0, 0, 0);
		if (x509)
		{
			result = SSL_CTX_use_certificate(context->context, x509) == 1;
			/* The context holds its own reference. */
			X509_free(x509);
		}
		while (result)
		{
			x509 = PEM_read_bio_X509(bio, 0, 0, 0);
			if (!x509)
			{
				break;
			}
			/* Ownership passes to the context only on success. */
			if (SSL_CTX_add_extra_chain_cert(context->context, x509) != 1)
			{
				X509_free(x509);
			}
		}
		if (result)
		{
			/* Reaching the end of the PEM text queues an expected error; drop
			** it so it is not reported later as a session error. */
			ERR_clear_error();
		}
		BIO_free(bio);
	}
#elif defined(TF_TLS_APPLE)
#elif defined(TF_TLS_SCHANNEL)
#endif
	return result;
}

/*
** Installs a PEM private key.  Returns true on success.
*/
bool tf_tls_context_set_private_key(tf_tls_context_t* context, const char* private_key)
{
	bool result = false;
#if defined(TF_TLS_OPENSSL)
	BIO* bio = BIO_new(BIO_s_mem());
	if (bio)
	{
		BIO_puts(bio, private_key);
		EVP_PKEY* key = PEM_read_bio_PrivateKey(bio, 0, 0, 0);
		if (key)
		{
			result = SSL_CTX_use_PrivateKey(context->context, key) == 1;
			/* The context holds its own reference. */
			EVP_PKEY_free(key);
		}
		BIO_free(bio);
	}
#elif defined(TF_TLS_APPLE)
#elif defined(TF_TLS_SCHANNEL)
#endif
	return result;
}

/*
** Adds the first certificate in the PEM text to the context's trust store.
** Returns true if it was added.
*/
bool tf_tls_context_add_trusted_certificate(tf_tls_context_t* context, const char* certificate)
{
	bool result = false;
#if defined(TF_TLS_OPENSSL)
	BIO* bio = BIO_new_mem_buf(certificate, -1);
	if (bio)
	{
		X509* x509 = PEM_read_bio_X509(bio, 0, 0, 0);
		BIO_free(bio);
		if (x509)
		{
			X509_STORE* store = SSL_CTX_get_cert_store(context->context);
			if (store && X509_STORE_add_cert(store, x509) == 1)
			{
				result = true;
			}
			X509_free(x509);
		}
	}
#elif defined(TF_TLS_APPLE)
#elif defined(TF_TLS_SCHANNEL)
#endif
	return result;
}

/*
** Creates a session bound to the context.  Returns NULL on allocation failure
** or when no back-end is implemented.  Release with tf_tls_session_destroy().
*/
tf_tls_session_t* tf_tls_context_create_session(tf_tls_context_t* context)
{
	tf_tls_session_t* session = NULL;
#if defined(TF_TLS_OPENSSL)
	session = malloc(sizeof(tf_tls_session_t));
	if (session)
	{
		memset(session, 0, sizeof(*session));
		session->context = context;
		session->bio_in = BIO_new(BIO_s_mem());
		session->bio_out = BIO_new(BIO_s_mem());
		if (!session->bio_in || !session->bio_out)
		{
			BIO_free(session->bio_in);
			BIO_free(session->bio_out);
			free(session);
			session = NULL;
		}
	}
#elif defined(TF_TLS_APPLE)
#elif defined(TF_TLS_SCHANNEL)
#endif
	return session;
}

/*
** Destroys a context created by tf_tls_context_create().  NULL is ignored.
*/
void tf_tls_context_destroy(tf_tls_context_t* context)
{
	if (!context)
	{
		return;
	}
#if defined(TF_TLS_OPENSSL)
	SSL_CTX_free(context->context);
#elif defined(TF_TLS_APPLE)
#elif defined(TF_TLS_SCHANNEL)
#endif
	/* The struct is allocated on every platform, so free it on every platform. */
	free(context);
}

/*
** Destroys a session created by tf_tls_context_create_session().  NULL is ignored.
*/
void tf_tls_session_destroy(tf_tls_session_t* session)
{
	if (!session)
	{
		return;
	}
#if defined(TF_TLS_OPENSSL)
	if (session->ssl)
	{
		/* The SSL object owns both BIOs once SSL_set_bio() has been called. */
		SSL_free(session->ssl);
	}
	else
	{
		/* Never started: the BIOs are still owned by the session. */
		BIO_free(session->bio_in);
		BIO_free(session->bio_out);
	}
	if (session->hostname)
	{
		free((void*)session->hostname);
	}
#elif defined(TF_TLS_APPLE)
#elif defined(TF_TLS_SCHANNEL)
#endif
	free(session);
}

/*
** Stores a copy of the hostname to verify the peer against, replacing any
** previous value.  NULL clears it.
*/
void tf_tls_session_set_hostname(tf_tls_session_t* session, const char* hostname)
{
#if defined(TF_TLS_OPENSSL)
	if (session->hostname)
	{
		free((void*)session->hostname);
		session->hostname = NULL;
	}
	if (hostname)
	{
		session->hostname = strdup(hostname);
	}
#elif defined(TF_TLS_APPLE)
#elif defined(TF_TLS_SCHANNEL)
#endif
}

/*
** Starts the server side of the handshake.  Does nothing if the session was
** already started or the SSL object cannot be created.
*/
void tf_tls_session_start_accept(tf_tls_session_t* session)
{
#if defined(TF_TLS_OPENSSL)
	if (session->ssl)
	{
		return;
	}
	session->direction = k_direction_accept;
	session->ssl = SSL_new(session->context->context);
	if (!session->ssl)
	{
		return;
	}
	SSL_set_bio(session->ssl, session->bio_in, session->bio_out);
	SSL_accept(session->ssl);
	tf_tls_session_handshake(session);
#elif defined(TF_TLS_APPLE)
#elif defined(TF_TLS_SCHANNEL)
#endif
}

/*
** Starts the client side of the handshake, verifying the peer against the
** session hostname.  Does nothing if the session was already started or the
** SSL object cannot be created.
*/
void tf_tls_session_start_connect(tf_tls_session_t* session)
{
#if defined(TF_TLS_OPENSSL)
	if (session->ssl)
	{
		return;
	}
	session->direction = k_direction_connect;
	session->ssl = SSL_new(session->context->context);
	if (!session->ssl)
	{
		return;
	}
	X509_VERIFY_PARAM* param = SSL_get0_param(session->ssl);
	X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
	X509_VERIFY_PARAM_set1_host(param, session->hostname, 0);
	SSL_set_bio(session->ssl, session->bio_in, session->bio_out);
	SSL_connect(session->ssl);
	tf_tls_session_handshake(session);
#elif defined(TF_TLS_APPLE)
#elif defined(TF_TLS_SCHANNEL)
#endif
}

/*
** Sends the TLS close notification, if the session was started.
*/
void tf_tls_session_shutdown(tf_tls_session_t* session)
{
#if defined(TF_TLS_OPENSSL)
	if (session->ssl)
	{
		SSL_shutdown(session->ssl);
	}
#elif defined(TF_TLS_APPLE)
#elif defined(TF_TLS_SCHANNEL)
#endif
}

/*
** Writes the peer certificate as PEM into buffer (not NUL-terminated).
** Returns the number of bytes written, or -1 if there is no peer certificate
** or it does not fit in `bytes`.
*/
int tf_tls_session_get_peer_certificate(tf_tls_session_t* session, char* buffer, size_t bytes)
{
	int result = -1;
#if defined(TF_TLS_OPENSSL)
	X509* certificate = session->ssl ? SSL_get_peer_certificate(session->ssl) : NULL;
	if (certificate)
	{
		BIO* bio = BIO_new(BIO_s_mem());
		if (bio)
		{
			if (PEM_write_bio_X509(bio, certificate) == 1)
			{
				BUF_MEM* mem = NULL;
				BIO_get_mem_ptr(bio, &mem);
				if (mem && mem->length <= bytes)
				{
					memcpy(buffer, mem->data, mem->length);
					result = (int)mem->length;
				}
			}
			BIO_free(bio);
		}
		/* SSL_get_peer_certificate() returns a new reference. */
		X509_free(certificate);
	}
#elif defined(TF_TLS_APPLE)
#elif defined(TF_TLS_SCHANNEL)
#endif
	return result;
}

#if defined(TF_TLS_OPENSSL)
#if OPENSSL_VERSION_NUMBER < 0x10100000L
/*
** Case-insensitively matches a NUL-terminated host name against a pattern of
** pattern_length bytes.  A '*' in the pattern stands for a run of characters
** within a single DNS label (it never crosses a '.').  Returns true only if
** both the whole pattern and the whole name were consumed.
*/
bool _tls_session_wildcard_match(const char* pattern, size_t pattern_length, const char* name)
{
	size_t index = 0;
	while (index < pattern_length && *name)
	{
		if (pattern[index] == '*')
		{
			/* Try every split point, passing only the part of the pattern
			** that is actually left after the '*'. */
			const char* rest = pattern + index + 1;
			size_t rest_length = pattern_length - index - 1;
			for (const char* p = name; *p; ++p)
			{
				if (_tls_session_wildcard_match(rest, rest_length, p))
				{
					return true;
				}
				if (*p == '.')
				{
					break;
				}
			}
			return false;
		}
		else if (tolower((unsigned char)pattern[index]) == tolower((unsigned char)*name))
		{
			++index;
			++name;
		}
		else
		{
			break;
		}
	}
	return index == pattern_length && *name == 0;
}

/*
** Checks whether the certificate names the given host, first through its DNS
** subject alternative names and then through its common name.
*/
static bool _tls_session_verify_hostname(X509* certificate, const char* hostname)
{
	bool verified = false;
	if (!hostname)
	{
		return false;
	}

	GENERAL_NAMES* names = X509_get_ext_d2i(certificate, NID_subject_alt_name, 0, 0);
	if (names)
	{
		int count = sk_GENERAL_NAME_num(names);
		for (int i = 0; i < count && !verified; ++i)
		{
			const GENERAL_NAME* check = sk_GENERAL_NAME_value(names, i);
			/* Only DNS entries carry an IA5 host name. */
			if (check && check->type == GEN_DNS)
			{
				const unsigned char* name = ASN1_STRING_data(check->d.ia5);
				size_t length = ASN1_STRING_length(check->d.ia5);
				if (name && _tls_session_wildcard_match((const char*)name, length, hostname))
				{
					verified = true;
				}
			}
		}
		/* Frees the stack together with its entries. */
		GENERAL_NAMES_free(names);
	}

	if (!verified)
	{
		int index = X509_NAME_get_index_by_NID(X509_get_subject_name(certificate), NID_commonName, -1);
		if (index >= 0)
		{
			X509_NAME_ENTRY* entry = X509_NAME_get_entry(X509_get_subject_name(certificate), index);
			if (entry)
			{
				ASN1_STRING* asn1 = X509_NAME_ENTRY_get_data(entry);
				if (asn1)
				{
					const unsigned char* commonName = ASN1_STRING_data(asn1);
					/* Reject names with embedded NUL bytes. */
					if (commonName && (size_t)(ASN1_STRING_length(asn1)) == strlen((const char*)commonName))
					{
						verified = _tls_session_wildcard_match((const char*)commonName, ASN1_STRING_length(asn1), hostname);
					}
				}
			}
		}
	}

	return verified;
}
#endif

/*
** Returns true if the peer presented a certificate that passed chain
** verification (and, on OpenSSL older than 1.1.0, names the session hostname).
*/
static bool _tls_session_verify_peer_certificate(tf_tls_session_t* session)
{
	bool verified = false;
	X509* certificate = SSL_get_peer_certificate(session->ssl);
	if (certificate)
	{
		if (SSL_get_verify_result(session->ssl) == X509_V_OK)
		{
#if OPENSSL_VERSION_NUMBER < 0x10100000L
			if (_tls_session_verify_hostname(certificate, session->hostname))
			{
				verified = true;
			}
#else
			verified = true;
#endif
		}
		X509_free(certificate);
	}
	return verified;
}
#endif

/*
** Advances the handshake.  Returns k_tls_handshake_more while more data must
** be exchanged, k_tls_handshake_done on success, and k_tls_handshake_failed
** on error, on a client-side peer verification failure, or when the session
** was never started.
*/
tf_tls_handshake_t tf_tls_session_handshake(tf_tls_session_t* session)
{
	tf_tls_handshake_t result = k_tls_handshake_failed;
#if defined(TF_TLS_OPENSSL)
	if (!session->ssl)
	{
		return result;
	}
	result = k_tls_handshake_done;
	if (!SSL_is_init_finished(session->ssl))
	{
		int value = SSL_do_handshake(session->ssl);
		if (value <= 0)
		{
			int error = SSL_get_error(session->ssl, value);
			if (error != SSL_ERROR_WANT_READ && error != SSL_ERROR_WANT_WRITE)
			{
				result = k_tls_handshake_failed;
			}
			else
			{
				result = k_tls_handshake_more;
			}
		}
	}
	if (result == k_tls_handshake_done && session->direction == k_direction_connect && !_tls_session_verify_peer_certificate(session))
	{
		result = k_tls_handshake_failed;
	}
#elif defined(TF_TLS_APPLE)
#elif defined(TF_TLS_SCHANNEL)
#endif
	return result;
}

/*
** Reads decrypted bytes.  Returns the byte count, 0 when no data is available
** yet, k_tls_read_zero when the peer shut the connection down, and
** k_tls_read_failed on error.
*/
int tf_tls_session_read_plain(tf_tls_session_t* session, char* buffer, size_t bytes)
{
	int result = k_tls_read_failed;
#if defined(TF_TLS_OPENSSL)
	if (!session->ssl)
	{
		return result;
	}
	result = SSL_read(session->ssl, buffer, (int)bytes);
	if (result <= 0)
	{
		int error = SSL_get_error(session->ssl, result);
		if (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE)
		{
			result = 0;
		}
		else if (error == SSL_ERROR_ZERO_RETURN)
		{
			if ((SSL_get_shutdown(session->ssl) & SSL_RECEIVED_SHUTDOWN) != 0)
			{
				result = k_tls_read_zero;
			}
			else
			{
				result = 0;
			}
		}
		else
		{
			result = k_tls_read_failed;
		}
	}
#elif defined(TF_TLS_APPLE)
#elif defined(TF_TLS_SCHANNEL)
#endif
	return result;
}

/*
** Encrypts plaintext for the peer.  Returns the SSL_write() result, or -1 if
** the session was never started.
*/
int tf_tls_session_write_plain(tf_tls_session_t* session, const char* buffer, size_t bytes)
{
	int result = -1;
#if defined(TF_TLS_OPENSSL)
	if (session->ssl)
	{
		result = SSL_write(session->ssl, buffer, (int)bytes);
	}
#elif defined(TF_TLS_APPLE)
#elif defined(TF_TLS_SCHANNEL)
#endif
	return result;
}

/*
** Drains encrypted bytes that must be sent to the peer.  Returns the
** BIO_read() result.
*/
int tf_tls_session_read_encrypted(tf_tls_session_t* session, char* buffer, size_t bytes)
{
	int result = -1;
#if defined(TF_TLS_OPENSSL)
	result = BIO_read(session->bio_out, buffer, (int)bytes);
#elif defined(TF_TLS_APPLE)
#elif defined(TF_TLS_SCHANNEL)
#endif
	return result;
}

/*
** Feeds encrypted bytes received from the peer.  Returns the BIO_write() result.
*/
int tf_tls_session_write_encrypted(tf_tls_session_t* session, const char* buffer, size_t bytes)
{
	int result = -1;
#if defined(TF_TLS_OPENSSL)
	result = BIO_write(session->bio_in, buffer, (int)bytes);
#elif defined(TF_TLS_APPLE)
#elif defined(TF_TLS_SCHANNEL)
#endif
	return result;
}

/*
** Pops the oldest queued OpenSSL error and formats it into buffer.
** Returns false if no error was queued.
*/
bool tf_tls_session_get_error(tf_tls_session_t* session, char* buffer, size_t bytes)
{
	bool result = false;
#if defined(TF_TLS_OPENSSL)
	unsigned long error = ERR_get_error();
	if (error != 0)
	{
		ERR_error_string_n(error, buffer, bytes);
		result = true;
	}
#elif defined(TF_TLS_APPLE)
#elif defined(TF_TLS_SCHANNEL)
#endif
	return result;
}