diff --git a/src/Socket.cpp b/src/Socket.cpp index 41b79e4a..a5891ebc 100644 --- a/src/Socket.cpp +++ b/src/Socket.cpp @@ -483,32 +483,60 @@ void Socket::notifyDataRead(const char* data, size_t length) { } } +int Socket::writeBytes(std::function callback, v8::Handle value, int* outLength) { + int result = -1; + + if (value->IsArrayBufferView()) { + v8::Handle array = v8::Handle::Cast(value); + result = callback(reinterpret_cast(array->Buffer()->GetContents().Data()), array->Buffer()->GetContents().ByteLength()); + if (outLength) { + *outLength = array->Buffer()->GetContents().ByteLength(); + } + } else if (value->IsString()) { + v8::Handle stringValue = v8::Handle::Cast(value); + if (stringValue->ContainsOnlyOneByte()) { + std::vector bytes(stringValue->Length()); + stringValue->WriteOneByte(bytes.data(), 0, bytes.size(), v8::String::NO_NULL_TERMINATION); + result = callback(reinterpret_cast(bytes.data()), bytes.size()); + if (outLength) { + *outLength = stringValue->Length(); + } + } else { + v8::String::Utf8Value utf8(stringValue); + result = callback(*utf8, utf8.length()); + if (outLength) { + *outLength = utf8.length(); + } + } + } + + return result; +} + +int Socket::writeInternal(promiseid_t promise, const char* data, size_t length) { + char* rawBuffer = new char[sizeof(uv_write_t) + length]; + uv_write_t* request = reinterpret_cast(rawBuffer); + std::memcpy(rawBuffer + sizeof(uv_write_t), data, length); + + uv_buf_t buffer; + buffer.base = rawBuffer + sizeof(uv_write_t); + buffer.len = length; + + request->data = reinterpret_cast(promise); + return uv_write(request, reinterpret_cast(&_socket), &buffer, 1, onWrite); +} + void Socket::write(const v8::FunctionCallbackInfo& args) { if (Socket* socket = Socket::get(args.Data())) { promiseid_t promise = socket->_task->allocatePromise(); args.GetReturnValue().Set(socket->_task->getPromise(promise)); v8::Handle value = args[0]; - if (!value.IsEmpty() && (value->IsString() || value->IsUint8Array())) { + if (!value.IsEmpty()) { if (socket->_tls) { socket->reportTlsErrors(); - int result; - int length; - if (value->IsArrayBufferView()) { - v8::Handle array = v8::Handle::Cast(value); - result = socket->_tls->writePlain(reinterpret_cast(array->Buffer()->GetContents().Data()), array->Buffer()->GetContents().ByteLength()); - } else if (value->IsString()) { - v8::Handle stringValue = v8::Handle::Cast(value); - if (stringValue->ContainsOnlyOneByte()) { - length = stringValue->Length(); - std::vector bytes(length); - stringValue->WriteOneByte(bytes.data(), 0, bytes.size(), v8::String::NO_NULL_TERMINATION); - result = socket->_tls->writePlain(reinterpret_cast(bytes.data()), bytes.size()); - } else { - v8::String::Utf8Value utf8(stringValue); - length = utf8.length(); - result = socket->_tls->writePlain(*utf8, utf8.length()); - } - } + int length = 0; + std::function writeFunction = std::bind(&TlsSession::writePlain, socket->_tls, std::placeholders::_1, std::placeholders::_2); + int result = socket->writeBytes(writeFunction, value, &length); char buffer[8192]; if (result <= 0 && socket->_tls->getError(buffer, sizeof(buffer))) { socket->_task->rejectPromise(promise, v8::String::NewFromUtf8(args.GetIsolate(), buffer)); @@ -520,34 +548,10 @@ void Socket::write(const v8::FunctionCallbackInfo& args) { socket->processOutgoingTls(); } else { int length; - char* rawBuffer = 0; - v8::Handle stringValue = v8::Handle::Cast(value); - if (stringValue->IsArrayBufferView()) { - v8::Handle array = v8::Handle::Cast(value); - length = array->Buffer()->GetContents().ByteLength(); - rawBuffer = new char[sizeof(uv_write_t) + length]; - std::memcpy(rawBuffer + sizeof(uv_write_t), array->Buffer()->GetContents().Data(), length); - } else if (value->IsString()) { - v8::String::Utf8Value utf8(stringValue); - if (stringValue->ContainsOnlyOneByte()) { - length = stringValue->Length(); - rawBuffer = new char[sizeof(uv_write_t) + length]; - stringValue->WriteOneByte(reinterpret_cast(rawBuffer) + sizeof(uv_write_t), 0, length, v8::String::NO_NULL_TERMINATION); - } else { - v8::String::Utf8Value utf8(stringValue); - length = utf8.length(); - rawBuffer = new char[sizeof(uv_write_t) + length]; - std::memcpy(rawBuffer + sizeof(uv_write_t), *utf8, length); - } - } - uv_write_t* request = reinterpret_cast(rawBuffer); - uv_buf_t buffer; - buffer.base = rawBuffer + sizeof(uv_write_t); - buffer.len = length; + std::function writeFunction = std::bind(&Socket::writeInternal, socket, promise, std::placeholders::_1, std::placeholders::_2); + int result = socket->writeBytes(writeFunction, value, &length); - request->data = reinterpret_cast(promise); - int result = uv_write(request, reinterpret_cast(&socket->_socket), &buffer, 1, onWrite); if (result != 0) { std::string error = "uv_write: " + std::string(uv_strerror(result)); socket->_task->rejectPromise(promise, v8::String::NewFromUtf8(args.GetIsolate(), error.c_str(), v8::String::kNormalString, error.size())); diff --git a/src/Socket.h b/src/Socket.h index a5759487..f7dc37d1 100644 --- a/src/Socket.h +++ b/src/Socket.h @@ -75,6 +75,8 @@ private: static void onRelease(const v8::WeakCallbackInfo& data); void notifyDataRead(const char* data, size_t length); + int writeBytes(std::function callback, v8::Handle value, int* outLength); + int writeInternal(promiseid_t promise, const char* data, size_t length); void processTlsShutdown(promiseid_t promise); static void onTlsShutdown(uv_write_t* request, int status); void shutdownInternal(promiseid_t promise);