mirror of
				https://github.com/yuzu-emu/yuzu-android
				synced 2025-10-25 14:02:26 -07:00 
			
		
		
		
	Updates:
- Address PR feedback. - Add SecureTransport backend for macOS.
This commit is contained in:
		| @@ -443,15 +443,28 @@ void BSD::Close(HLERequestContext& ctx) { | ||||
| } | ||||
|  | ||||
| void BSD::DuplicateSocket(HLERequestContext& ctx) { | ||||
|     IPC::RequestParser rp{ctx}; | ||||
|     const s32 fd = rp.Pop<s32>(); | ||||
|     [[maybe_unused]] const u64 unused = rp.Pop<u64>(); | ||||
|     struct InputParameters { | ||||
|         s32 fd; | ||||
|         u64 reserved; | ||||
|     }; | ||||
|     static_assert(sizeof(InputParameters) == 0x10); | ||||
|  | ||||
|     Expected<s32, Errno> res = DuplicateSocketImpl(fd); | ||||
|     struct OutputParameters { | ||||
|         s32 ret; | ||||
|         Errno bsd_errno; | ||||
|     }; | ||||
|     static_assert(sizeof(OutputParameters) == 0x8); | ||||
|  | ||||
|     IPC::RequestParser rp{ctx}; | ||||
|     auto input = rp.PopRaw<InputParameters>(); | ||||
|  | ||||
|     Expected<s32, Errno> res = DuplicateSocketImpl(input.fd); | ||||
|     IPC::ResponseBuilder rb{ctx, 4}; | ||||
|     rb.Push(ResultSuccess); | ||||
|     rb.Push(res.value_or(0));                         // ret | ||||
|     rb.Push(res ? 0 : static_cast<s32>(res.error())); // bsd errno | ||||
|     rb.PushRaw(OutputParameters{ | ||||
|         .ret = res.value_or(0), | ||||
|         .bsd_errno = res ? Errno::SUCCESS : res.error(), | ||||
|     }); | ||||
| } | ||||
|  | ||||
| void BSD::EventFd(HLERequestContext& ctx) { | ||||
|   | ||||
| @@ -131,14 +131,15 @@ static std::vector<u8> SerializeAddrInfoAsHostEnt(const std::vector<Network::Add | ||||
| } | ||||
|  | ||||
| static std::pair<u32, GetAddrInfoError> GetHostByNameRequestImpl(HLERequestContext& ctx) { | ||||
|     struct Parameters { | ||||
|     struct InputParameters { | ||||
|         u8 use_nsd_resolve; | ||||
|         u32 cancel_handle; | ||||
|         u64 process_id; | ||||
|     }; | ||||
|     static_assert(sizeof(InputParameters) == 0x10); | ||||
|  | ||||
|     IPC::RequestParser rp{ctx}; | ||||
|     const auto parameters = rp.PopRaw<Parameters>(); | ||||
|     const auto parameters = rp.PopRaw<InputParameters>(); | ||||
|  | ||||
|     LOG_WARNING( | ||||
|         Service, | ||||
| @@ -164,21 +165,39 @@ static std::pair<u32, GetAddrInfoError> GetHostByNameRequestImpl(HLERequestConte | ||||
| void SFDNSRES::GetHostByNameRequest(HLERequestContext& ctx) { | ||||
|     auto [data_size, emu_gai_err] = GetHostByNameRequestImpl(ctx); | ||||
|  | ||||
|     struct OutputParameters { | ||||
|         NetDbError netdb_error; | ||||
|         Errno bsd_errno; | ||||
|         u32 data_size; | ||||
|     }; | ||||
|     static_assert(sizeof(OutputParameters) == 0xc); | ||||
|  | ||||
|     IPC::ResponseBuilder rb{ctx, 5}; | ||||
|     rb.Push(ResultSuccess); | ||||
|     rb.Push(static_cast<s32>(GetAddrInfoErrorToNetDbError(emu_gai_err))); // netdb error code | ||||
|     rb.Push(static_cast<s32>(GetAddrInfoErrorToErrno(emu_gai_err)));      // errno | ||||
|     rb.Push(data_size);                                                   // serialized size | ||||
|     rb.PushRaw(OutputParameters{ | ||||
|         .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err), | ||||
|         .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err), | ||||
|         .data_size = data_size, | ||||
|     }); | ||||
| } | ||||
|  | ||||
| void SFDNSRES::GetHostByNameRequestWithOptions(HLERequestContext& ctx) { | ||||
|     auto [data_size, emu_gai_err] = GetHostByNameRequestImpl(ctx); | ||||
|  | ||||
|     struct OutputParameters { | ||||
|         u32 data_size; | ||||
|         NetDbError netdb_error; | ||||
|         Errno bsd_errno; | ||||
|     }; | ||||
|     static_assert(sizeof(OutputParameters) == 0xc); | ||||
|  | ||||
|     IPC::ResponseBuilder rb{ctx, 5}; | ||||
|     rb.Push(ResultSuccess); | ||||
|     rb.Push(data_size);                                                   // serialized size | ||||
|     rb.Push(static_cast<s32>(GetAddrInfoErrorToNetDbError(emu_gai_err))); // netdb error code | ||||
|     rb.Push(static_cast<s32>(GetAddrInfoErrorToErrno(emu_gai_err)));      // errno | ||||
|     rb.PushRaw(OutputParameters{ | ||||
|         .data_size = data_size, | ||||
|         .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err), | ||||
|         .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err), | ||||
|     }); | ||||
| } | ||||
|  | ||||
| static std::vector<u8> SerializeAddrInfo(const std::vector<Network::AddrInfo>& vec, | ||||
| @@ -221,14 +240,15 @@ static std::vector<u8> SerializeAddrInfo(const std::vector<Network::AddrInfo>& v | ||||
| } | ||||
|  | ||||
| static std::pair<u32, GetAddrInfoError> GetAddrInfoRequestImpl(HLERequestContext& ctx) { | ||||
|     struct Parameters { | ||||
|     struct InputParameters { | ||||
|         u8 use_nsd_resolve; | ||||
|         u32 cancel_handle; | ||||
|         u64 process_id; | ||||
|     }; | ||||
|     static_assert(sizeof(InputParameters) == 0x10); | ||||
|  | ||||
|     IPC::RequestParser rp{ctx}; | ||||
|     const auto parameters = rp.PopRaw<Parameters>(); | ||||
|     const auto parameters = rp.PopRaw<InputParameters>(); | ||||
|  | ||||
|     LOG_WARNING( | ||||
|         Service, | ||||
| @@ -264,23 +284,42 @@ static std::pair<u32, GetAddrInfoError> GetAddrInfoRequestImpl(HLERequestContext | ||||
| void SFDNSRES::GetAddrInfoRequest(HLERequestContext& ctx) { | ||||
|     auto [data_size, emu_gai_err] = GetAddrInfoRequestImpl(ctx); | ||||
|  | ||||
|     struct OutputParameters { | ||||
|         Errno bsd_errno; | ||||
|         GetAddrInfoError gai_error; | ||||
|         u32 data_size; | ||||
|     }; | ||||
|     static_assert(sizeof(OutputParameters) == 0xc); | ||||
|  | ||||
|     IPC::ResponseBuilder rb{ctx, 5}; | ||||
|     rb.Push(ResultSuccess); | ||||
|     rb.Push(static_cast<s32>(GetAddrInfoErrorToErrno(emu_gai_err))); // errno | ||||
|     rb.Push(static_cast<s32>(emu_gai_err));                          // getaddrinfo error code | ||||
|     rb.Push(data_size);                                              // serialized size | ||||
|     rb.PushRaw(OutputParameters{ | ||||
|         .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err), | ||||
|         .gai_error = emu_gai_err, | ||||
|         .data_size = data_size, | ||||
|     }); | ||||
| } | ||||
|  | ||||
| void SFDNSRES::GetAddrInfoRequestWithOptions(HLERequestContext& ctx) { | ||||
|     // Additional options are ignored | ||||
|     auto [data_size, emu_gai_err] = GetAddrInfoRequestImpl(ctx); | ||||
|  | ||||
|     struct OutputParameters { | ||||
|         u32 data_size; | ||||
|         GetAddrInfoError gai_error; | ||||
|         NetDbError netdb_error; | ||||
|         Errno bsd_errno; | ||||
|     }; | ||||
|     static_assert(sizeof(OutputParameters) == 0x10); | ||||
|  | ||||
|     IPC::ResponseBuilder rb{ctx, 6}; | ||||
|     rb.Push(ResultSuccess); | ||||
|     rb.Push(data_size);                                                   // serialized size | ||||
|     rb.Push(static_cast<s32>(emu_gai_err));                               // getaddrinfo error code | ||||
|     rb.Push(static_cast<s32>(GetAddrInfoErrorToNetDbError(emu_gai_err))); // netdb error code | ||||
|     rb.Push(static_cast<s32>(GetAddrInfoErrorToErrno(emu_gai_err)));      // errno | ||||
|     rb.PushRaw(OutputParameters{ | ||||
|         .data_size = data_size, | ||||
|         .gai_error = emu_gai_err, | ||||
|         .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err), | ||||
|         .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err), | ||||
|     }); | ||||
| } | ||||
|  | ||||
| void SFDNSRES::ResolverSetOptionRequest(HLERequestContext& ctx) { | ||||
|   | ||||
| @@ -64,7 +64,7 @@ public: | ||||
|                             std::shared_ptr<SslContextSharedData>& shared_data, | ||||
|                             std::unique_ptr<SSLConnectionBackend>&& backend) | ||||
|         : ServiceFramework{system_, "ISslConnection"}, ssl_version{version}, | ||||
|           shared_data_{shared_data}, backend_{std::move(backend)} { | ||||
|           shared_data{shared_data}, backend{std::move(backend)} { | ||||
|         // clang-format off | ||||
|         static const FunctionInfo functions[] = { | ||||
|             {0, &ISslConnection::SetSocketDescriptor, "SetSocketDescriptor"}, | ||||
| @@ -112,10 +112,10 @@ public: | ||||
|     } | ||||
|  | ||||
|     ~ISslConnection() { | ||||
|         shared_data_->connection_count--; | ||||
|         if (fd_to_close_.has_value()) { | ||||
|             const s32 fd = *fd_to_close_; | ||||
|             if (!do_not_close_socket_) { | ||||
|         shared_data->connection_count--; | ||||
|         if (fd_to_close.has_value()) { | ||||
|             const s32 fd = *fd_to_close; | ||||
|             if (!do_not_close_socket) { | ||||
|                 LOG_ERROR(Service_SSL, | ||||
|                           "do_not_close_socket was changed after setting socket; is this right?"); | ||||
|             } else { | ||||
| @@ -132,30 +132,30 @@ public: | ||||
|  | ||||
| private: | ||||
|     SslVersion ssl_version; | ||||
|     std::shared_ptr<SslContextSharedData> shared_data_; | ||||
|     std::unique_ptr<SSLConnectionBackend> backend_; | ||||
|     std::optional<int> fd_to_close_; | ||||
|     bool do_not_close_socket_ = false; | ||||
|     bool get_server_cert_chain_ = false; | ||||
|     std::shared_ptr<Network::SocketBase> socket_; | ||||
|     bool did_set_host_name_ = false; | ||||
|     bool did_handshake_ = false; | ||||
|     std::shared_ptr<SslContextSharedData> shared_data; | ||||
|     std::unique_ptr<SSLConnectionBackend> backend; | ||||
|     std::optional<int> fd_to_close; | ||||
|     bool do_not_close_socket = false; | ||||
|     bool get_server_cert_chain = false; | ||||
|     std::shared_ptr<Network::SocketBase> socket; | ||||
|     bool did_set_host_name = false; | ||||
|     bool did_handshake = false; | ||||
|  | ||||
|     ResultVal<s32> SetSocketDescriptorImpl(s32 fd) { | ||||
|         LOG_DEBUG(Service_SSL, "called, fd={}", fd); | ||||
|         ASSERT(!did_handshake_); | ||||
|         ASSERT(!did_handshake); | ||||
|         auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u"); | ||||
|         ASSERT_OR_EXECUTE(bsd, { return ResultInternalError; }); | ||||
|         s32 ret_fd; | ||||
|         // Based on https://switchbrew.org/wiki/SSL_services#SetSocketDescriptor | ||||
|         if (do_not_close_socket_) { | ||||
|         if (do_not_close_socket) { | ||||
|             auto res = bsd->DuplicateSocketImpl(fd); | ||||
|             if (!res.has_value()) { | ||||
|                 LOG_ERROR(Service_SSL, "Failed to duplicate socket with fd {}", fd); | ||||
|                 return ResultInvalidSocket; | ||||
|             } | ||||
|             fd = *res; | ||||
|             fd_to_close_ = fd; | ||||
|             fd_to_close = fd; | ||||
|             ret_fd = fd; | ||||
|         } else { | ||||
|             ret_fd = -1; | ||||
| @@ -165,34 +165,34 @@ private: | ||||
|             LOG_ERROR(Service_SSL, "invalid socket fd {}", fd); | ||||
|             return ResultInvalidSocket; | ||||
|         } | ||||
|         socket_ = std::move(*sock); | ||||
|         backend_->SetSocket(socket_); | ||||
|         socket = std::move(*sock); | ||||
|         backend->SetSocket(socket); | ||||
|         return ret_fd; | ||||
|     } | ||||
|  | ||||
|     Result SetHostNameImpl(const std::string& hostname) { | ||||
|         LOG_DEBUG(Service_SSL, "called. hostname={}", hostname); | ||||
|         ASSERT(!did_handshake_); | ||||
|         Result res = backend_->SetHostName(hostname); | ||||
|         ASSERT(!did_handshake); | ||||
|         Result res = backend->SetHostName(hostname); | ||||
|         if (res == ResultSuccess) { | ||||
|             did_set_host_name_ = true; | ||||
|             did_set_host_name = true; | ||||
|         } | ||||
|         return res; | ||||
|     } | ||||
|  | ||||
|     Result SetVerifyOptionImpl(u32 option) { | ||||
|         ASSERT(!did_handshake_); | ||||
|         ASSERT(!did_handshake); | ||||
|         LOG_WARNING(Service_SSL, "(STUBBED) called. option={}", option); | ||||
|         return ResultSuccess; | ||||
|     } | ||||
|  | ||||
|     Result SetIOModeImpl(u32 _mode) { | ||||
|         auto mode = static_cast<IoMode>(_mode); | ||||
|     Result SetIoModeImpl(u32 input_mode) { | ||||
|         auto mode = static_cast<IoMode>(input_mode); | ||||
|         ASSERT(mode == IoMode::Blocking || mode == IoMode::NonBlocking); | ||||
|         ASSERT_OR_EXECUTE(socket_, { return ResultNoSocket; }); | ||||
|         ASSERT_OR_EXECUTE(socket, { return ResultNoSocket; }); | ||||
|  | ||||
|         const bool non_block = mode == IoMode::NonBlocking; | ||||
|         const Network::Errno error = socket_->SetNonBlock(non_block); | ||||
|         const Network::Errno error = socket->SetNonBlock(non_block); | ||||
|         if (error != Network::Errno::SUCCESS) { | ||||
|             LOG_ERROR(Service_SSL, "Failed to set native socket non-block flag to {}", non_block); | ||||
|         } | ||||
| @@ -200,18 +200,18 @@ private: | ||||
|     } | ||||
|  | ||||
|     Result SetSessionCacheModeImpl(u32 mode) { | ||||
|         ASSERT(!did_handshake_); | ||||
|         ASSERT(!did_handshake); | ||||
|         LOG_WARNING(Service_SSL, "(STUBBED) called. value={}", mode); | ||||
|         return ResultSuccess; | ||||
|     } | ||||
|  | ||||
|     Result DoHandshakeImpl() { | ||||
|         ASSERT_OR_EXECUTE(!did_handshake_ && socket_, { return ResultNoSocket; }); | ||||
|         ASSERT_OR_EXECUTE(!did_handshake && socket, { return ResultNoSocket; }); | ||||
|         ASSERT_OR_EXECUTE_MSG( | ||||
|             did_set_host_name_, { return ResultInternalError; }, | ||||
|             did_set_host_name, { return ResultInternalError; }, | ||||
|             "Expected SetHostName before DoHandshake"); | ||||
|         Result res = backend_->DoHandshake(); | ||||
|         did_handshake_ = res.IsSuccess(); | ||||
|         Result res = backend->DoHandshake(); | ||||
|         did_handshake = res.IsSuccess(); | ||||
|         return res; | ||||
|     } | ||||
|  | ||||
| @@ -225,7 +225,7 @@ private: | ||||
|             u32 size; | ||||
|             u32 offset; | ||||
|         }; | ||||
|         if (!get_server_cert_chain_) { | ||||
|         if (!get_server_cert_chain) { | ||||
|             // Just return the first one, unencoded. | ||||
|             ASSERT_OR_EXECUTE_MSG( | ||||
|                 !certs.empty(), { return {}; }, "Should be at least one server cert"); | ||||
| @@ -248,9 +248,9 @@ private: | ||||
|     } | ||||
|  | ||||
|     ResultVal<std::vector<u8>> ReadImpl(size_t size) { | ||||
|         ASSERT_OR_EXECUTE(did_handshake_, { return ResultInternalError; }); | ||||
|         ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; }); | ||||
|         std::vector<u8> res(size); | ||||
|         ResultVal<size_t> actual = backend_->Read(res); | ||||
|         ResultVal<size_t> actual = backend->Read(res); | ||||
|         if (actual.Failed()) { | ||||
|             return actual.Code(); | ||||
|         } | ||||
| @@ -259,8 +259,8 @@ private: | ||||
|     } | ||||
|  | ||||
|     ResultVal<size_t> WriteImpl(std::span<const u8> data) { | ||||
|         ASSERT_OR_EXECUTE(did_handshake_, { return ResultInternalError; }); | ||||
|         return backend_->Write(data); | ||||
|         ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; }); | ||||
|         return backend->Write(data); | ||||
|     } | ||||
|  | ||||
|     ResultVal<s32> PendingImpl() { | ||||
| @@ -295,7 +295,7 @@ private: | ||||
|     void SetIoMode(HLERequestContext& ctx) { | ||||
|         IPC::RequestParser rp{ctx}; | ||||
|         const u32 mode = rp.Pop<u32>(); | ||||
|         const Result res = SetIOModeImpl(mode); | ||||
|         const Result res = SetIoModeImpl(mode); | ||||
|         IPC::ResponseBuilder rb{ctx, 2}; | ||||
|         rb.Push(res); | ||||
|     } | ||||
| @@ -307,22 +307,26 @@ private: | ||||
|     } | ||||
|  | ||||
|     void DoHandshakeGetServerCert(HLERequestContext& ctx) { | ||||
|         struct OutputParameters { | ||||
|             u32 certs_size; | ||||
|             u32 certs_count; | ||||
|         }; | ||||
|         static_assert(sizeof(OutputParameters) == 0x8); | ||||
|  | ||||
|         const Result res = DoHandshakeImpl(); | ||||
|         u32 certs_count = 0; | ||||
|         u32 certs_size = 0; | ||||
|         OutputParameters out{}; | ||||
|         if (res == ResultSuccess) { | ||||
|             auto certs = backend_->GetServerCerts(); | ||||
|             auto certs = backend->GetServerCerts(); | ||||
|             if (certs.Succeeded()) { | ||||
|                 const std::vector<u8> certs_buf = SerializeServerCerts(*certs); | ||||
|                 ctx.WriteBuffer(certs_buf); | ||||
|                 certs_count = static_cast<u32>(certs->size()); | ||||
|                 certs_size = static_cast<u32>(certs_buf.size()); | ||||
|                 out.certs_count = static_cast<u32>(certs->size()); | ||||
|                 out.certs_size = static_cast<u32>(certs_buf.size()); | ||||
|             } | ||||
|         } | ||||
|         IPC::ResponseBuilder rb{ctx, 4}; | ||||
|         rb.Push(res); | ||||
|         rb.Push(certs_size); | ||||
|         rb.Push(certs_count); | ||||
|         rb.PushRaw(out); | ||||
|     } | ||||
|  | ||||
|     void Read(HLERequestContext& ctx) { | ||||
| @@ -371,10 +375,10 @@ private: | ||||
|  | ||||
|         switch (parameters.option) { | ||||
|         case OptionType::DoNotCloseSocket: | ||||
|             do_not_close_socket_ = static_cast<bool>(parameters.value); | ||||
|             do_not_close_socket = static_cast<bool>(parameters.value); | ||||
|             break; | ||||
|         case OptionType::GetServerCertChain: | ||||
|             get_server_cert_chain_ = static_cast<bool>(parameters.value); | ||||
|             get_server_cert_chain = static_cast<bool>(parameters.value); | ||||
|             break; | ||||
|         default: | ||||
|             LOG_WARNING(Service_SSL, "Unknown option={}, value={}", parameters.option, | ||||
| @@ -390,7 +394,7 @@ class ISslContext final : public ServiceFramework<ISslContext> { | ||||
| public: | ||||
|     explicit ISslContext(Core::System& system_, SslVersion version) | ||||
|         : ServiceFramework{system_, "ISslContext"}, ssl_version{version}, | ||||
|           shared_data_{std::make_shared<SslContextSharedData>()} { | ||||
|           shared_data{std::make_shared<SslContextSharedData>()} { | ||||
|         static const FunctionInfo functions[] = { | ||||
|             {0, &ISslContext::SetOption, "SetOption"}, | ||||
|             {1, nullptr, "GetOption"}, | ||||
| @@ -412,7 +416,7 @@ public: | ||||
|  | ||||
| private: | ||||
|     SslVersion ssl_version; | ||||
|     std::shared_ptr<SslContextSharedData> shared_data_; | ||||
|     std::shared_ptr<SslContextSharedData> shared_data; | ||||
|  | ||||
|     void SetOption(HLERequestContext& ctx) { | ||||
|         struct Parameters { | ||||
| @@ -439,17 +443,17 @@ private: | ||||
|         IPC::ResponseBuilder rb{ctx, 2, 0, 1}; | ||||
|         rb.Push(backend_res.Code()); | ||||
|         if (backend_res.Succeeded()) { | ||||
|             rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data_, | ||||
|             rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data, | ||||
|                                                 std::move(*backend_res)); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     void GetConnectionCount(HLERequestContext& ctx) { | ||||
|         LOG_WARNING(Service_SSL, "connection_count={}", shared_data_->connection_count); | ||||
|         LOG_DEBUG(Service_SSL, "connection_count={}", shared_data->connection_count); | ||||
|  | ||||
|         IPC::ResponseBuilder rb{ctx, 3}; | ||||
|         rb.Push(ResultSuccess); | ||||
|         rb.Push(shared_data_->connection_count); | ||||
|         rb.Push(shared_data->connection_count); | ||||
|     } | ||||
|  | ||||
|     void ImportServerPki(HLERequestContext& ctx) { | ||||
|   | ||||
| @@ -51,37 +51,37 @@ public: | ||||
|             return ResultInternalError; | ||||
|         } | ||||
|  | ||||
|         ssl_ = SSL_new(ssl_ctx); | ||||
|         if (!ssl_) { | ||||
|         ssl = SSL_new(ssl_ctx); | ||||
|         if (!ssl) { | ||||
|             LOG_ERROR(Service_SSL, "SSL_new failed"); | ||||
|             return CheckOpenSSLErrors(); | ||||
|         } | ||||
|  | ||||
|         SSL_set_connect_state(ssl_); | ||||
|         SSL_set_connect_state(ssl); | ||||
|  | ||||
|         bio_ = BIO_new(bio_meth); | ||||
|         if (!bio_) { | ||||
|         bio = BIO_new(bio_meth); | ||||
|         if (!bio) { | ||||
|             LOG_ERROR(Service_SSL, "BIO_new failed"); | ||||
|             return CheckOpenSSLErrors(); | ||||
|         } | ||||
|  | ||||
|         BIO_set_data(bio_, this); | ||||
|         BIO_set_init(bio_, 1); | ||||
|         SSL_set_bio(ssl_, bio_, bio_); | ||||
|         BIO_set_data(bio, this); | ||||
|         BIO_set_init(bio, 1); | ||||
|         SSL_set_bio(ssl, bio, bio); | ||||
|  | ||||
|         return ResultSuccess; | ||||
|     } | ||||
|  | ||||
|     void SetSocket(std::shared_ptr<Network::SocketBase> socket) override { | ||||
|         socket_ = socket; | ||||
|     void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override { | ||||
|         socket = std::move(socket_in); | ||||
|     } | ||||
|  | ||||
|     Result SetHostName(const std::string& hostname) override { | ||||
|         if (!SSL_set1_host(ssl_, hostname.c_str())) { // hostname for verification | ||||
|         if (!SSL_set1_host(ssl, hostname.c_str())) { // hostname for verification | ||||
|             LOG_ERROR(Service_SSL, "SSL_set1_host({}) failed", hostname); | ||||
|             return CheckOpenSSLErrors(); | ||||
|         } | ||||
|         if (!SSL_set_tlsext_host_name(ssl_, hostname.c_str())) { // hostname for SNI | ||||
|         if (!SSL_set_tlsext_host_name(ssl, hostname.c_str())) { // hostname for SNI | ||||
|             LOG_ERROR(Service_SSL, "SSL_set_tlsext_host_name({}) failed", hostname); | ||||
|             return CheckOpenSSLErrors(); | ||||
|         } | ||||
| @@ -89,18 +89,18 @@ public: | ||||
|     } | ||||
|  | ||||
|     Result DoHandshake() override { | ||||
|         SSL_set_verify_result(ssl_, X509_V_OK); | ||||
|         const int ret = SSL_do_handshake(ssl_); | ||||
|         const long verify_result = SSL_get_verify_result(ssl_); | ||||
|         SSL_set_verify_result(ssl, X509_V_OK); | ||||
|         const int ret = SSL_do_handshake(ssl); | ||||
|         const long verify_result = SSL_get_verify_result(ssl); | ||||
|         if (verify_result != X509_V_OK) { | ||||
|             LOG_ERROR(Service_SSL, "SSL cert verification failed because: {}", | ||||
|                       X509_verify_cert_error_string(verify_result)); | ||||
|             return CheckOpenSSLErrors(); | ||||
|         } | ||||
|         if (ret <= 0) { | ||||
|             const int ssl_err = SSL_get_error(ssl_, ret); | ||||
|             const int ssl_err = SSL_get_error(ssl, ret); | ||||
|             if (ssl_err == SSL_ERROR_ZERO_RETURN || | ||||
|                 (ssl_err == SSL_ERROR_SYSCALL && got_read_eof_)) { | ||||
|                 (ssl_err == SSL_ERROR_SYSCALL && got_read_eof)) { | ||||
|                 LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up"); | ||||
|                 return ResultInternalError; | ||||
|             } | ||||
| @@ -110,18 +110,18 @@ public: | ||||
|  | ||||
|     ResultVal<size_t> Read(std::span<u8> data) override { | ||||
|         size_t actual; | ||||
|         const int ret = SSL_read_ex(ssl_, data.data(), data.size(), &actual); | ||||
|         const int ret = SSL_read_ex(ssl, data.data(), data.size(), &actual); | ||||
|         return HandleReturn("SSL_read_ex", actual, ret); | ||||
|     } | ||||
|  | ||||
|     ResultVal<size_t> Write(std::span<const u8> data) override { | ||||
|         size_t actual; | ||||
|         const int ret = SSL_write_ex(ssl_, data.data(), data.size(), &actual); | ||||
|         const int ret = SSL_write_ex(ssl, data.data(), data.size(), &actual); | ||||
|         return HandleReturn("SSL_write_ex", actual, ret); | ||||
|     } | ||||
|  | ||||
|     ResultVal<size_t> HandleReturn(const char* what, size_t actual, int ret) { | ||||
|         const int ssl_err = SSL_get_error(ssl_, ret); | ||||
|         const int ssl_err = SSL_get_error(ssl, ret); | ||||
|         CheckOpenSSLErrors(); | ||||
|         switch (ssl_err) { | ||||
|         case SSL_ERROR_NONE: | ||||
| @@ -137,7 +137,7 @@ public: | ||||
|             LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_WRITE", what); | ||||
|             return ResultWouldBlock; | ||||
|         default: | ||||
|             if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof_) { | ||||
|             if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof) { | ||||
|                 LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_SYSCALL because server hung up", what); | ||||
|                 return size_t(0); | ||||
|             } | ||||
| @@ -147,7 +147,7 @@ public: | ||||
|     } | ||||
|  | ||||
|     ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { | ||||
|         STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl_); | ||||
|         STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl); | ||||
|         if (!chain) { | ||||
|             LOG_ERROR(Service_SSL, "SSL_get_peer_cert_chain returned nullptr"); | ||||
|             return ResultInternalError; | ||||
| @@ -169,8 +169,8 @@ public: | ||||
|  | ||||
|     ~SSLConnectionBackendOpenSSL() { | ||||
|         // these are null-tolerant: | ||||
|         SSL_free(ssl_); | ||||
|         BIO_free(bio_); | ||||
|         SSL_free(ssl); | ||||
|         BIO_free(bio); | ||||
|     } | ||||
|  | ||||
|     static void KeyLogCallback(const SSL* ssl, const char* line) { | ||||
| @@ -188,9 +188,9 @@ public: | ||||
|     static int WriteCallback(BIO* bio, const char* buf, size_t len, size_t* actual_p) { | ||||
|         auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio)); | ||||
|         ASSERT_OR_EXECUTE_MSG( | ||||
|             self->socket_, { return 0; }, "OpenSSL asked to send but we have no socket"); | ||||
|             self->socket, { return 0; }, "OpenSSL asked to send but we have no socket"); | ||||
|         BIO_clear_retry_flags(bio); | ||||
|         auto [actual, err] = self->socket_->Send({reinterpret_cast<const u8*>(buf), len}, 0); | ||||
|         auto [actual, err] = self->socket->Send({reinterpret_cast<const u8*>(buf), len}, 0); | ||||
|         switch (err) { | ||||
|         case Network::Errno::SUCCESS: | ||||
|             *actual_p = actual; | ||||
| @@ -207,14 +207,14 @@ public: | ||||
|     static int ReadCallback(BIO* bio, char* buf, size_t len, size_t* actual_p) { | ||||
|         auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio)); | ||||
|         ASSERT_OR_EXECUTE_MSG( | ||||
|             self->socket_, { return 0; }, "OpenSSL asked to recv but we have no socket"); | ||||
|             self->socket, { return 0; }, "OpenSSL asked to recv but we have no socket"); | ||||
|         BIO_clear_retry_flags(bio); | ||||
|         auto [actual, err] = self->socket_->Recv(0, {reinterpret_cast<u8*>(buf), len}); | ||||
|         auto [actual, err] = self->socket->Recv(0, {reinterpret_cast<u8*>(buf), len}); | ||||
|         switch (err) { | ||||
|         case Network::Errno::SUCCESS: | ||||
|             *actual_p = actual; | ||||
|             if (actual == 0) { | ||||
|                 self->got_read_eof_ = true; | ||||
|                 self->got_read_eof = true; | ||||
|             } | ||||
|             return actual ? 1 : 0; | ||||
|         case Network::Errno::AGAIN: | ||||
| @@ -246,11 +246,11 @@ public: | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     SSL* ssl_ = nullptr; | ||||
|     BIO* bio_ = nullptr; | ||||
|     bool got_read_eof_ = false; | ||||
|     SSL* ssl = nullptr; | ||||
|     BIO* bio = nullptr; | ||||
|     bool got_read_eof = false; | ||||
|  | ||||
|     std::shared_ptr<Network::SocketBase> socket_; | ||||
|     std::shared_ptr<Network::SocketBase> socket; | ||||
| }; | ||||
|  | ||||
| ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { | ||||
|   | ||||
| @@ -48,6 +48,12 @@ static void OneTimeInit() { | ||||
|         return; | ||||
|     } | ||||
|  | ||||
|     if (getenv("SSLKEYLOGFILE")) { | ||||
|         LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but Schannel does not support exporting " | ||||
|                                   "keys; not logging keys!"); | ||||
|         // Not fatal. | ||||
|     } | ||||
|  | ||||
|     one_time_init_success = true; | ||||
| } | ||||
|  | ||||
| @@ -70,25 +76,25 @@ public: | ||||
|         return ResultSuccess; | ||||
|     } | ||||
|  | ||||
|     void SetSocket(std::shared_ptr<Network::SocketBase> socket) override { | ||||
|         socket_ = socket; | ||||
|     void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override { | ||||
|         socket = std::move(socket_in); | ||||
|     } | ||||
|  | ||||
|     Result SetHostName(const std::string& hostname) override { | ||||
|         hostname_ = hostname; | ||||
|     Result SetHostName(const std::string& hostname_in) override { | ||||
|         hostname = hostname_in; | ||||
|         return ResultSuccess; | ||||
|     } | ||||
|  | ||||
|     Result DoHandshake() override { | ||||
|         while (1) { | ||||
|             Result r; | ||||
|             switch (handshake_state_) { | ||||
|             switch (handshake_state) { | ||||
|             case HandshakeState::Initial: | ||||
|                 if ((r = FlushCiphertextWriteBuf()) != ResultSuccess || | ||||
|                     (r = CallInitializeSecurityContext()) != ResultSuccess) { | ||||
|                     return r; | ||||
|                 } | ||||
|                 // CallInitializeSecurityContext updated `handshake_state_`. | ||||
|                 // CallInitializeSecurityContext updated `handshake_state`. | ||||
|                 continue; | ||||
|             case HandshakeState::ContinueNeeded: | ||||
|             case HandshakeState::IncompleteMessage: | ||||
| @@ -96,20 +102,20 @@ public: | ||||
|                     (r = FillCiphertextReadBuf()) != ResultSuccess) { | ||||
|                     return r; | ||||
|                 } | ||||
|                 if (ciphertext_read_buf_.empty()) { | ||||
|                 if (ciphertext_read_buf.empty()) { | ||||
|                     LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up"); | ||||
|                     return ResultInternalError; | ||||
|                 } | ||||
|                 if ((r = CallInitializeSecurityContext()) != ResultSuccess) { | ||||
|                     return r; | ||||
|                 } | ||||
|                 // CallInitializeSecurityContext updated `handshake_state_`. | ||||
|                 // CallInitializeSecurityContext updated `handshake_state`. | ||||
|                 continue; | ||||
|             case HandshakeState::DoneAfterFlush: | ||||
|                 if ((r = FlushCiphertextWriteBuf()) != ResultSuccess) { | ||||
|                     return r; | ||||
|                 } | ||||
|                 handshake_state_ = HandshakeState::Connected; | ||||
|                 handshake_state = HandshakeState::Connected; | ||||
|                 return ResultSuccess; | ||||
|             case HandshakeState::Connected: | ||||
|                 LOG_ERROR(Service_SSL, "Called DoHandshake but we already handshook"); | ||||
| @@ -121,24 +127,24 @@ public: | ||||
|     } | ||||
|  | ||||
|     Result FillCiphertextReadBuf() { | ||||
|         const size_t fill_size = read_buf_fill_size_ ? read_buf_fill_size_ : 4096; | ||||
|         read_buf_fill_size_ = 0; | ||||
|         const size_t fill_size = read_buf_fill_size ? read_buf_fill_size : 4096; | ||||
|         read_buf_fill_size = 0; | ||||
|         // This unnecessarily zeroes the buffer; oh well. | ||||
|         const size_t offset = ciphertext_read_buf_.size(); | ||||
|         const size_t offset = ciphertext_read_buf.size(); | ||||
|         ASSERT_OR_EXECUTE(offset + fill_size >= offset, { return ResultInternalError; }); | ||||
|         ciphertext_read_buf_.resize(offset + fill_size, 0); | ||||
|         const auto read_span = std::span(ciphertext_read_buf_).subspan(offset, fill_size); | ||||
|         const auto [actual, err] = socket_->Recv(0, read_span); | ||||
|         ciphertext_read_buf.resize(offset + fill_size, 0); | ||||
|         const auto read_span = std::span(ciphertext_read_buf).subspan(offset, fill_size); | ||||
|         const auto [actual, err] = socket->Recv(0, read_span); | ||||
|         switch (err) { | ||||
|         case Network::Errno::SUCCESS: | ||||
|             ASSERT(static_cast<size_t>(actual) <= fill_size); | ||||
|             ciphertext_read_buf_.resize(offset + actual); | ||||
|             ciphertext_read_buf.resize(offset + actual); | ||||
|             return ResultSuccess; | ||||
|         case Network::Errno::AGAIN: | ||||
|             ciphertext_read_buf_.resize(offset); | ||||
|             ciphertext_read_buf.resize(offset); | ||||
|             return ResultWouldBlock; | ||||
|         default: | ||||
|             ciphertext_read_buf_.resize(offset); | ||||
|             ciphertext_read_buf.resize(offset); | ||||
|             LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err); | ||||
|             return ResultInternalError; | ||||
|         } | ||||
| @@ -146,13 +152,13 @@ public: | ||||
|  | ||||
|     // Returns success if the write buffer has been completely emptied. | ||||
|     Result FlushCiphertextWriteBuf() { | ||||
|         while (!ciphertext_write_buf_.empty()) { | ||||
|             const auto [actual, err] = socket_->Send(ciphertext_write_buf_, 0); | ||||
|         while (!ciphertext_write_buf.empty()) { | ||||
|             const auto [actual, err] = socket->Send(ciphertext_write_buf, 0); | ||||
|             switch (err) { | ||||
|             case Network::Errno::SUCCESS: | ||||
|                 ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf_.size()); | ||||
|                 ciphertext_write_buf_.erase(ciphertext_write_buf_.begin(), | ||||
|                                             ciphertext_write_buf_.begin() + actual); | ||||
|                 ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf.size()); | ||||
|                 ciphertext_write_buf.erase(ciphertext_write_buf.begin(), | ||||
|                                            ciphertext_write_buf.begin() + actual); | ||||
|                 break; | ||||
|             case Network::Errno::AGAIN: | ||||
|                 return ResultWouldBlock; | ||||
| @@ -175,9 +181,9 @@ public: | ||||
|             // only used if `initial_call_done` | ||||
|             { | ||||
|                 // [0] | ||||
|                 .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf_.size()), | ||||
|                 .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()), | ||||
|                 .BufferType = SECBUFFER_TOKEN, | ||||
|                 .pvBuffer = ciphertext_read_buf_.data(), | ||||
|                 .pvBuffer = ciphertext_read_buf.data(), | ||||
|             }, | ||||
|             { | ||||
|                 // [1] (will be replaced by SECBUFFER_MISSING when SEC_E_INCOMPLETE_MESSAGE is | ||||
| @@ -211,30 +217,30 @@ public: | ||||
|             .pBuffers = output_buffers.data(), | ||||
|         }; | ||||
|         ASSERT_OR_EXECUTE_MSG( | ||||
|             input_buffers[0].cbBuffer == ciphertext_read_buf_.size(), | ||||
|             input_buffers[0].cbBuffer == ciphertext_read_buf.size(), | ||||
|             { return ResultInternalError; }, "read buffer too large"); | ||||
|  | ||||
|         bool initial_call_done = handshake_state_ != HandshakeState::Initial; | ||||
|         bool initial_call_done = handshake_state != HandshakeState::Initial; | ||||
|         if (initial_call_done) { | ||||
|             LOG_DEBUG(Service_SSL, "Passing {} bytes into InitializeSecurityContext", | ||||
|                       ciphertext_read_buf_.size()); | ||||
|                       ciphertext_read_buf.size()); | ||||
|         } | ||||
|  | ||||
|         const SECURITY_STATUS ret = | ||||
|             InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt_ : nullptr, | ||||
|             InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt : nullptr, | ||||
|                                        // Caller ensured we have set a hostname: | ||||
|                                        const_cast<char*>(hostname_.value().c_str()), req, | ||||
|                                        const_cast<char*>(hostname.value().c_str()), req, | ||||
|                                        0, // Reserved1 | ||||
|                                        0, // TargetDataRep not used with Schannel | ||||
|                                        initial_call_done ? &input_desc : nullptr, | ||||
|                                        0, // Reserved2 | ||||
|                                        initial_call_done ? nullptr : &ctxt_, &output_desc, &attr, | ||||
|                                        initial_call_done ? nullptr : &ctxt, &output_desc, &attr, | ||||
|                                        nullptr); // ptsExpiry | ||||
|  | ||||
|         if (output_buffers[0].pvBuffer) { | ||||
|             const std::span span(static_cast<u8*>(output_buffers[0].pvBuffer), | ||||
|                                  output_buffers[0].cbBuffer); | ||||
|             ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), span.begin(), span.end()); | ||||
|             ciphertext_write_buf.insert(ciphertext_write_buf.end(), span.begin(), span.end()); | ||||
|             FreeContextBuffer(output_buffers[0].pvBuffer); | ||||
|         } | ||||
|  | ||||
| @@ -251,64 +257,64 @@ public: | ||||
|             LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_I_CONTINUE_NEEDED"); | ||||
|             if (input_buffers[1].BufferType == SECBUFFER_EXTRA) { | ||||
|                 LOG_DEBUG(Service_SSL, "EXTRA of size {}", input_buffers[1].cbBuffer); | ||||
|                 ASSERT(input_buffers[1].cbBuffer <= ciphertext_read_buf_.size()); | ||||
|                 ciphertext_read_buf_.erase(ciphertext_read_buf_.begin(), | ||||
|                                            ciphertext_read_buf_.end() - input_buffers[1].cbBuffer); | ||||
|                 ASSERT(input_buffers[1].cbBuffer <= ciphertext_read_buf.size()); | ||||
|                 ciphertext_read_buf.erase(ciphertext_read_buf.begin(), | ||||
|                                           ciphertext_read_buf.end() - input_buffers[1].cbBuffer); | ||||
|             } else { | ||||
|                 ASSERT(input_buffers[1].BufferType == SECBUFFER_EMPTY); | ||||
|                 ciphertext_read_buf_.clear(); | ||||
|                 ciphertext_read_buf.clear(); | ||||
|             } | ||||
|             handshake_state_ = HandshakeState::ContinueNeeded; | ||||
|             handshake_state = HandshakeState::ContinueNeeded; | ||||
|             return ResultSuccess; | ||||
|         case SEC_E_INCOMPLETE_MESSAGE: | ||||
|             LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_INCOMPLETE_MESSAGE"); | ||||
|             ASSERT(input_buffers[1].BufferType == SECBUFFER_MISSING); | ||||
|             read_buf_fill_size_ = input_buffers[1].cbBuffer; | ||||
|             handshake_state_ = HandshakeState::IncompleteMessage; | ||||
|             read_buf_fill_size = input_buffers[1].cbBuffer; | ||||
|             handshake_state = HandshakeState::IncompleteMessage; | ||||
|             return ResultSuccess; | ||||
|         case SEC_E_OK: | ||||
|             LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_OK"); | ||||
|             ciphertext_read_buf_.clear(); | ||||
|             handshake_state_ = HandshakeState::DoneAfterFlush; | ||||
|             ciphertext_read_buf.clear(); | ||||
|             handshake_state = HandshakeState::DoneAfterFlush; | ||||
|             return GrabStreamSizes(); | ||||
|         default: | ||||
|             LOG_ERROR(Service_SSL, | ||||
|                       "InitializeSecurityContext failed (probably certificate/protocol issue): {}", | ||||
|                       Common::NativeErrorToString(ret)); | ||||
|             handshake_state_ = HandshakeState::Error; | ||||
|             handshake_state = HandshakeState::Error; | ||||
|             return ResultInternalError; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     Result GrabStreamSizes() { | ||||
|         const SECURITY_STATUS ret = | ||||
|             QueryContextAttributes(&ctxt_, SECPKG_ATTR_STREAM_SIZES, &stream_sizes_); | ||||
|             QueryContextAttributes(&ctxt, SECPKG_ATTR_STREAM_SIZES, &stream_sizes); | ||||
|         if (ret != SEC_E_OK) { | ||||
|             LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_STREAM_SIZES) failed: {}", | ||||
|                       Common::NativeErrorToString(ret)); | ||||
|             handshake_state_ = HandshakeState::Error; | ||||
|             handshake_state = HandshakeState::Error; | ||||
|             return ResultInternalError; | ||||
|         } | ||||
|         return ResultSuccess; | ||||
|     } | ||||
|  | ||||
|     ResultVal<size_t> Read(std::span<u8> data) override { | ||||
|         if (handshake_state_ != HandshakeState::Connected) { | ||||
|         if (handshake_state != HandshakeState::Connected) { | ||||
|             LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake"); | ||||
|             return ResultInternalError; | ||||
|         } | ||||
|         if (data.size() == 0 || got_read_eof_) { | ||||
|         if (data.size() == 0 || got_read_eof) { | ||||
|             return size_t(0); | ||||
|         } | ||||
|         while (1) { | ||||
|             if (!cleartext_read_buf_.empty()) { | ||||
|                 const size_t read_size = std::min(cleartext_read_buf_.size(), data.size()); | ||||
|                 std::memcpy(data.data(), cleartext_read_buf_.data(), read_size); | ||||
|                 cleartext_read_buf_.erase(cleartext_read_buf_.begin(), | ||||
|                                           cleartext_read_buf_.begin() + read_size); | ||||
|             if (!cleartext_read_buf.empty()) { | ||||
|                 const size_t read_size = std::min(cleartext_read_buf.size(), data.size()); | ||||
|                 std::memcpy(data.data(), cleartext_read_buf.data(), read_size); | ||||
|                 cleartext_read_buf.erase(cleartext_read_buf.begin(), | ||||
|                                          cleartext_read_buf.begin() + read_size); | ||||
|                 return read_size; | ||||
|             } | ||||
|             if (!ciphertext_read_buf_.empty()) { | ||||
|             if (!ciphertext_read_buf.empty()) { | ||||
|                 SecBuffer empty{ | ||||
|                     .cbBuffer = 0, | ||||
|                     .BufferType = SECBUFFER_EMPTY, | ||||
| @@ -316,16 +322,16 @@ public: | ||||
|                 }; | ||||
|                 std::array<SecBuffer, 5> buffers{{ | ||||
|                     { | ||||
|                         .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf_.size()), | ||||
|                         .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()), | ||||
|                         .BufferType = SECBUFFER_DATA, | ||||
|                         .pvBuffer = ciphertext_read_buf_.data(), | ||||
|                         .pvBuffer = ciphertext_read_buf.data(), | ||||
|                     }, | ||||
|                     empty, | ||||
|                     empty, | ||||
|                     empty, | ||||
|                 }}; | ||||
|                 ASSERT_OR_EXECUTE_MSG( | ||||
|                     buffers[0].cbBuffer == ciphertext_read_buf_.size(), | ||||
|                     buffers[0].cbBuffer == ciphertext_read_buf.size(), | ||||
|                     { return ResultInternalError; }, "read buffer too large"); | ||||
|                 SecBufferDesc desc{ | ||||
|                     .ulVersion = SECBUFFER_VERSION, | ||||
| @@ -333,7 +339,7 @@ public: | ||||
|                     .pBuffers = buffers.data(), | ||||
|                 }; | ||||
|                 SECURITY_STATUS ret = | ||||
|                     DecryptMessage(&ctxt_, &desc, /*MessageSeqNo*/ 0, /*pfQOP*/ nullptr); | ||||
|                     DecryptMessage(&ctxt, &desc, /*MessageSeqNo*/ 0, /*pfQOP*/ nullptr); | ||||
|                 switch (ret) { | ||||
|                 case SEC_E_OK: | ||||
|                     ASSERT_OR_EXECUTE(buffers[0].BufferType == SECBUFFER_STREAM_HEADER, | ||||
| @@ -342,24 +348,23 @@ public: | ||||
|                                       { return ResultInternalError; }); | ||||
|                     ASSERT_OR_EXECUTE(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER, | ||||
|                                       { return ResultInternalError; }); | ||||
|                     cleartext_read_buf_.assign(static_cast<u8*>(buffers[1].pvBuffer), | ||||
|                                                static_cast<u8*>(buffers[1].pvBuffer) + | ||||
|                                                    buffers[1].cbBuffer); | ||||
|                     cleartext_read_buf.assign(static_cast<u8*>(buffers[1].pvBuffer), | ||||
|                                               static_cast<u8*>(buffers[1].pvBuffer) + | ||||
|                                                   buffers[1].cbBuffer); | ||||
|                     if (buffers[3].BufferType == SECBUFFER_EXTRA) { | ||||
|                         ASSERT(buffers[3].cbBuffer <= ciphertext_read_buf_.size()); | ||||
|                         ciphertext_read_buf_.erase(ciphertext_read_buf_.begin(), | ||||
|                                                    ciphertext_read_buf_.end() - | ||||
|                                                        buffers[3].cbBuffer); | ||||
|                         ASSERT(buffers[3].cbBuffer <= ciphertext_read_buf.size()); | ||||
|                         ciphertext_read_buf.erase(ciphertext_read_buf.begin(), | ||||
|                                                   ciphertext_read_buf.end() - buffers[3].cbBuffer); | ||||
|                     } else { | ||||
|                         ASSERT(buffers[3].BufferType == SECBUFFER_EMPTY); | ||||
|                         ciphertext_read_buf_.clear(); | ||||
|                         ciphertext_read_buf.clear(); | ||||
|                     } | ||||
|                     continue; | ||||
|                 case SEC_E_INCOMPLETE_MESSAGE: | ||||
|                     break; | ||||
|                 case SEC_I_CONTEXT_EXPIRED: | ||||
|                     // Server hung up by sending close_notify. | ||||
|                     got_read_eof_ = true; | ||||
|                     got_read_eof = true; | ||||
|                     return size_t(0); | ||||
|                 default: | ||||
|                     LOG_ERROR(Service_SSL, "DecryptMessage failed: {}", | ||||
| @@ -371,43 +376,43 @@ public: | ||||
|             if (r != ResultSuccess) { | ||||
|                 return r; | ||||
|             } | ||||
|             if (ciphertext_read_buf_.empty()) { | ||||
|                 got_read_eof_ = true; | ||||
|             if (ciphertext_read_buf.empty()) { | ||||
|                 got_read_eof = true; | ||||
|                 return size_t(0); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     ResultVal<size_t> Write(std::span<const u8> data) override { | ||||
|         if (handshake_state_ != HandshakeState::Connected) { | ||||
|         if (handshake_state != HandshakeState::Connected) { | ||||
|             LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake"); | ||||
|             return ResultInternalError; | ||||
|         } | ||||
|         if (data.size() == 0) { | ||||
|             return size_t(0); | ||||
|         } | ||||
|         data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes_.cbMaximumMessage)); | ||||
|         if (!cleartext_write_buf_.empty()) { | ||||
|         data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes.cbMaximumMessage)); | ||||
|         if (!cleartext_write_buf.empty()) { | ||||
|             // Already in the middle of a write.  It wouldn't make sense to not | ||||
|             // finish sending the entire buffer since TLS has | ||||
|             // header/MAC/padding/etc. | ||||
|             if (data.size() != cleartext_write_buf_.size() || | ||||
|                 std::memcmp(data.data(), cleartext_write_buf_.data(), data.size())) { | ||||
|             if (data.size() != cleartext_write_buf.size() || | ||||
|                 std::memcmp(data.data(), cleartext_write_buf.data(), data.size())) { | ||||
|                 LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer"); | ||||
|                 return ResultInternalError; | ||||
|             } | ||||
|             return WriteAlreadyEncryptedData(); | ||||
|         } else { | ||||
|             cleartext_write_buf_.assign(data.begin(), data.end()); | ||||
|             cleartext_write_buf.assign(data.begin(), data.end()); | ||||
|         } | ||||
|  | ||||
|         std::vector<u8> header_buf(stream_sizes_.cbHeader, 0); | ||||
|         std::vector<u8> tmp_data_buf = cleartext_write_buf_; | ||||
|         std::vector<u8> trailer_buf(stream_sizes_.cbTrailer, 0); | ||||
|         std::vector<u8> header_buf(stream_sizes.cbHeader, 0); | ||||
|         std::vector<u8> tmp_data_buf = cleartext_write_buf; | ||||
|         std::vector<u8> trailer_buf(stream_sizes.cbTrailer, 0); | ||||
|  | ||||
|         std::array<SecBuffer, 3> buffers{{ | ||||
|             { | ||||
|                 .cbBuffer = stream_sizes_.cbHeader, | ||||
|                 .cbBuffer = stream_sizes.cbHeader, | ||||
|                 .BufferType = SECBUFFER_STREAM_HEADER, | ||||
|                 .pvBuffer = header_buf.data(), | ||||
|             }, | ||||
| @@ -417,7 +422,7 @@ public: | ||||
|                 .pvBuffer = tmp_data_buf.data(), | ||||
|             }, | ||||
|             { | ||||
|                 .cbBuffer = stream_sizes_.cbTrailer, | ||||
|                 .cbBuffer = stream_sizes.cbTrailer, | ||||
|                 .BufferType = SECBUFFER_STREAM_TRAILER, | ||||
|                 .pvBuffer = trailer_buf.data(), | ||||
|             }, | ||||
| @@ -431,17 +436,17 @@ public: | ||||
|             .pBuffers = buffers.data(), | ||||
|         }; | ||||
|  | ||||
|         const SECURITY_STATUS ret = EncryptMessage(&ctxt_, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0); | ||||
|         const SECURITY_STATUS ret = EncryptMessage(&ctxt, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0); | ||||
|         if (ret != SEC_E_OK) { | ||||
|             LOG_ERROR(Service_SSL, "EncryptMessage failed: {}", Common::NativeErrorToString(ret)); | ||||
|             return ResultInternalError; | ||||
|         } | ||||
|         ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), header_buf.begin(), | ||||
|                                      header_buf.end()); | ||||
|         ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), tmp_data_buf.begin(), | ||||
|                                      tmp_data_buf.end()); | ||||
|         ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), trailer_buf.begin(), | ||||
|                                      trailer_buf.end()); | ||||
|         ciphertext_write_buf.insert(ciphertext_write_buf.end(), header_buf.begin(), | ||||
|                                     header_buf.end()); | ||||
|         ciphertext_write_buf.insert(ciphertext_write_buf.end(), tmp_data_buf.begin(), | ||||
|                                     tmp_data_buf.end()); | ||||
|         ciphertext_write_buf.insert(ciphertext_write_buf.end(), trailer_buf.begin(), | ||||
|                                     trailer_buf.end()); | ||||
|         return WriteAlreadyEncryptedData(); | ||||
|     } | ||||
|  | ||||
| @@ -451,15 +456,15 @@ public: | ||||
|             return r; | ||||
|         } | ||||
|         // write buf is empty | ||||
|         const size_t cleartext_bytes_written = cleartext_write_buf_.size(); | ||||
|         cleartext_write_buf_.clear(); | ||||
|         const size_t cleartext_bytes_written = cleartext_write_buf.size(); | ||||
|         cleartext_write_buf.clear(); | ||||
|         return cleartext_bytes_written; | ||||
|     } | ||||
|  | ||||
|     ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { | ||||
|         PCCERT_CONTEXT returned_cert = nullptr; | ||||
|         const SECURITY_STATUS ret = | ||||
|             QueryContextAttributes(&ctxt_, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert); | ||||
|             QueryContextAttributes(&ctxt, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert); | ||||
|         if (ret != SEC_E_OK) { | ||||
|             LOG_ERROR(Service_SSL, | ||||
|                       "QueryContextAttributes(SECPKG_ATTR_REMOTE_CERT_CONTEXT) failed: {}", | ||||
| @@ -480,8 +485,8 @@ public: | ||||
|     } | ||||
|  | ||||
|     ~SSLConnectionBackendSchannel() { | ||||
|         if (handshake_state_ != HandshakeState::Initial) { | ||||
|             DeleteSecurityContext(&ctxt_); | ||||
|         if (handshake_state != HandshakeState::Initial) { | ||||
|             DeleteSecurityContext(&ctxt); | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -509,21 +514,21 @@ public: | ||||
|         // Another error was returned and we shouldn't allow initialization | ||||
|         // to continue. | ||||
|         Error, | ||||
|     } handshake_state_ = HandshakeState::Initial; | ||||
|     } handshake_state = HandshakeState::Initial; | ||||
|  | ||||
|     CtxtHandle ctxt_; | ||||
|     SecPkgContext_StreamSizes stream_sizes_; | ||||
|     CtxtHandle ctxt; | ||||
|     SecPkgContext_StreamSizes stream_sizes; | ||||
|  | ||||
|     std::shared_ptr<Network::SocketBase> socket_; | ||||
|     std::optional<std::string> hostname_; | ||||
|     std::shared_ptr<Network::SocketBase> socket; | ||||
|     std::optional<std::string> hostname; | ||||
|  | ||||
|     std::vector<u8> ciphertext_read_buf_; | ||||
|     std::vector<u8> ciphertext_write_buf_; | ||||
|     std::vector<u8> cleartext_read_buf_; | ||||
|     std::vector<u8> cleartext_write_buf_; | ||||
|     std::vector<u8> ciphertext_read_buf; | ||||
|     std::vector<u8> ciphertext_write_buf; | ||||
|     std::vector<u8> cleartext_read_buf; | ||||
|     std::vector<u8> cleartext_write_buf; | ||||
|  | ||||
|     bool got_read_eof_ = false; | ||||
|     size_t read_buf_fill_size_ = 0; | ||||
|     bool got_read_eof = false; | ||||
|     size_t read_buf_fill_size = 0; | ||||
| }; | ||||
|  | ||||
| ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user