| | 1 | | // Copyright (c) ZeroC, Inc. |
| | 2 | |
|
| | 3 | | using System.Buffers; |
| | 4 | | using System.Diagnostics; |
| | 5 | | using System.Net; |
| | 6 | | using System.Net.Security; |
| | 7 | | using System.Net.Sockets; |
| | 8 | | using System.Runtime.InteropServices; |
| | 9 | |
|
| | 10 | | namespace IceRpc.Transports.Tcp.Internal; |
| | 11 | |
|
| | 12 | | /// <summary>Implements <see cref="IDuplexConnection" /> for tcp with or without TLS.</summary> |
| | 13 | | /// <remarks>Unlike Coloc, the Tcp transport is not a "checked" transport, which means it does not need to detect |
| | 14 | | /// violations of the duplex transport contract or report such violations. It assumes its clients are sufficiently well |
| | 15 | | /// tested to never violate this contract. As a result, this implementation does not throw |
| | 16 | | /// <see cref="InvalidOperationException" />.</remarks> |
| | 17 | | internal abstract class TcpConnection : IDuplexConnection |
| | 18 | | { |
| | 19 | | internal abstract Socket Socket { get; } |
| | 20 | |
|
| | 21 | | internal abstract SslStream? SslStream { get; } |
| | 22 | |
|
| | 23 | | private protected volatile bool _isDisposed; |
| | 24 | |
|
| | 25 | | // The MaxDataSize of the SSL implementation. |
| | 26 | | private const int MaxSslDataSize = 16 * 1024; |
| | 27 | |
|
| | 28 | | private bool _isShutdown; |
| | 29 | | private readonly int _maxSslBufferSize; |
| | 30 | | private readonly List<ArraySegment<byte>> _segments = new(); |
| | 31 | | private readonly IMemoryOwner<byte>? _writeBufferOwner; |
| | 32 | |
|
| | 33 | | public Task<TransportConnectionInformation> ConnectAsync(CancellationToken cancellationToken) |
| | 34 | | { |
| | 35 | | ObjectDisposedException.ThrowIf(_isDisposed, this); |
| | 36 | | return ConnectAsyncCore(cancellationToken); |
| | 37 | | } |
| | 38 | |
|
| | 39 | | public void Dispose() |
| | 40 | | { |
| | 41 | | _isDisposed = true; |
| | 42 | |
|
| | 43 | | if (SslStream is SslStream sslStream) |
| | 44 | | { |
| | 45 | | sslStream.Dispose(); |
| | 46 | | } |
| | 47 | |
|
| | 48 | | // If shutdown was called, we can just dispose the connection to complete the graceful TCP closure. Otherwise, |
| | 49 | | // we abort the TCP connection to ensure the connection doesn't end up in the TIME_WAIT state. |
| | 50 | | if (_isShutdown) |
| | 51 | | { |
| | 52 | | Socket.Dispose(); |
| | 53 | | } |
| | 54 | | else |
| | 55 | | { |
| | 56 | | Socket.Close(0); |
| | 57 | | } |
| | 58 | | _writeBufferOwner?.Dispose(); |
| | 59 | | } |
| | 60 | |
|
| | 61 | | public ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken) |
| | 62 | | { |
| | 63 | | ObjectDisposedException.ThrowIf(_isDisposed, this); |
| | 64 | |
|
| | 65 | | return buffer.Length > 0 ? PerformReadAsync() : |
| | 66 | | throw new ArgumentException($"The {nameof(buffer)} cannot be empty.", nameof(buffer)); |
| | 67 | |
|
| | 68 | | async ValueTask<int> PerformReadAsync() |
| | 69 | | { |
| | 70 | | try |
| | 71 | | { |
| | 72 | | return SslStream is SslStream sslStream ? |
| | 73 | | await SslStream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false) : |
| | 74 | | await Socket.ReceiveAsync(buffer, SocketFlags.None, cancellationToken).ConfigureAwait(false); |
| | 75 | | } |
| | 76 | | catch (IOException exception) |
| | 77 | | { |
| | 78 | | throw exception.ToIceRpcException(); |
| | 79 | | } |
| | 80 | | catch (SocketException exception) |
| | 81 | | { |
| | 82 | | throw exception.ToIceRpcException(); |
| | 83 | | } |
| | 84 | | } |
| | 85 | | } |
| | 86 | |
|
| | 87 | | public Task ShutdownWriteAsync(CancellationToken cancellationToken) |
| | 88 | | { |
| | 89 | | ObjectDisposedException.ThrowIf(_isDisposed, this); |
| | 90 | |
|
| | 91 | | return PerformShutdownAsync(); |
| | 92 | |
|
| | 93 | | async Task PerformShutdownAsync() |
| | 94 | | { |
| | 95 | | try |
| | 96 | | { |
| | 97 | | if (SslStream is SslStream sslStream) |
| | 98 | | { |
| | 99 | | Task shutdownTask = sslStream.ShutdownAsync(); |
| | 100 | |
|
| | 101 | | try |
| | 102 | | { |
| | 103 | | await shutdownTask.WaitAsync(cancellationToken).ConfigureAwait(false); |
| | 104 | | } |
| | 105 | | catch (OperationCanceledException) |
| | 106 | | { |
| | 107 | | await AbortAndObserveAsync(shutdownTask).ConfigureAwait(false); |
| | 108 | | throw; |
| | 109 | | } |
| | 110 | | } |
| | 111 | |
|
| | 112 | | // Shutdown the socket send side to send a TCP FIN packet. We don't close the read side because we want |
| | 113 | | // to be notified when the peer shuts down it's side of the socket (through the ReceiveAsync call). |
| | 114 | | Socket.Shutdown(SocketShutdown.Send); |
| | 115 | |
|
| | 116 | | // If shutdown is successful mark the connection as shutdown to ensure Dispose won't reset the TCP |
| | 117 | | // connection. |
| | 118 | | _isShutdown = true; |
| | 119 | | } |
| | 120 | | catch (IOException exception) |
| | 121 | | { |
| | 122 | | throw exception.ToIceRpcException(); |
| | 123 | | } |
| | 124 | | catch (SocketException exception) |
| | 125 | | { |
| | 126 | | throw exception.ToIceRpcException(); |
| | 127 | | } |
| | 128 | | } |
| | 129 | | } |
| | 130 | |
|
| | 131 | | public ValueTask WriteAsync(ReadOnlySequence<byte> buffer, CancellationToken cancellationToken) |
| | 132 | | { |
| | 133 | | ObjectDisposedException.ThrowIf(_isDisposed, this); |
| | 134 | | return PerformWriteAsync(); |
| | 135 | |
|
| | 136 | | async ValueTask PerformWriteAsync() |
| | 137 | | { |
| | 138 | | try |
| | 139 | | { |
| | 140 | | if (SslStream is SslStream sslStream) |
| | 141 | | { |
| | 142 | | if (buffer.IsSingleSegment) |
| | 143 | | { |
| | 144 | | await sslStream.WriteAsync(buffer.First, cancellationToken).ConfigureAwait(false); |
| | 145 | | } |
| | 146 | | else |
| | 147 | | { |
| | 148 | | // Coalesce leading segments up to _maxSslBufferSize. We don't coalesce trailing segments as we |
| | 149 | | // assume these segments are large enough. |
| | 150 | | int leadingSize = 0; |
| | 151 | | int leadingSegmentCount = 0; |
| | 152 | | foreach (ReadOnlyMemory<byte> memory in buffer) |
| | 153 | | { |
| | 154 | | if (leadingSize + memory.Length <= _maxSslBufferSize) |
| | 155 | | { |
| | 156 | | leadingSize += memory.Length; |
| | 157 | | leadingSegmentCount++; |
| | 158 | | } |
| | 159 | | else |
| | 160 | | { |
| | 161 | | break; |
| | 162 | | } |
| | 163 | | } |
| | 164 | |
|
| | 165 | | if (leadingSegmentCount > 1) |
| | 166 | | { |
| | 167 | | ReadOnlySequence<byte> leading = buffer.Slice(0, leadingSize); |
| | 168 | | buffer = buffer.Slice(leadingSize); // buffer can become empty |
| | 169 | |
|
| | 170 | | Debug.Assert(_writeBufferOwner is not null); |
| | 171 | | Memory<byte> writeBuffer = _writeBufferOwner.Memory[0..leadingSize]; |
| | 172 | | leading.CopyTo(writeBuffer.Span); |
| | 173 | |
|
| | 174 | | // Send the "coalesced" leading segments |
| | 175 | | await sslStream.WriteAsync(writeBuffer, cancellationToken).ConfigureAwait(false); |
| | 176 | | } |
| | 177 | | // else no need to coalesce (copy) a single segment |
| | 178 | |
|
| | 179 | | // Send the remaining segments one by one |
| | 180 | | if (buffer.IsEmpty) |
| | 181 | | { |
| | 182 | | // done |
| | 183 | | } |
| | 184 | | else if (buffer.IsSingleSegment) |
| | 185 | | { |
| | 186 | | await sslStream.WriteAsync(buffer.First, cancellationToken).ConfigureAwait(false); |
| | 187 | | } |
| | 188 | | else |
| | 189 | | { |
| | 190 | | foreach (ReadOnlyMemory<byte> memory in buffer) |
| | 191 | | { |
| | 192 | | await sslStream.WriteAsync(memory, cancellationToken).ConfigureAwait(false); |
| | 193 | | } |
| | 194 | | } |
| | 195 | | } |
| | 196 | | } |
| | 197 | | else |
| | 198 | | { |
| | 199 | | if (buffer.IsSingleSegment) |
| | 200 | | { |
| | 201 | | _ = await Socket.SendAsync(buffer.First, SocketFlags.None, cancellationToken) |
| | 202 | | .ConfigureAwait(false); |
| | 203 | | } |
| | 204 | | else |
| | 205 | | { |
| | 206 | | _segments.Clear(); |
| | 207 | | foreach (ReadOnlyMemory<byte> memory in buffer) |
| | 208 | | { |
| | 209 | | if (MemoryMarshal.TryGetArray(memory, out ArraySegment<byte> segment)) |
| | 210 | | { |
| | 211 | | _segments.Add(segment); |
| | 212 | | } |
| | 213 | | else |
| | 214 | | { |
| | 215 | | throw new ArgumentException( |
| | 216 | | $"The {nameof(buffer)} must be backed by arrays.", |
| | 217 | | nameof(buffer)); |
| | 218 | | } |
| | 219 | | } |
| | 220 | |
|
| | 221 | | Task sendTask = Socket.SendAsync(_segments, SocketFlags.None); |
| | 222 | |
|
| | 223 | | try |
| | 224 | | { |
| | 225 | | await sendTask.WaitAsync(cancellationToken).ConfigureAwait(false); |
| | 226 | | } |
| | 227 | | catch (OperationCanceledException) |
| | 228 | | { |
| | 229 | | await AbortAndObserveAsync(sendTask).ConfigureAwait(false); |
| | 230 | | throw; |
| | 231 | | } |
| | 232 | | } |
| | 233 | | } |
| | 234 | | } |
| | 235 | | catch (IOException exception) |
| | 236 | | { |
| | 237 | | throw exception.ToIceRpcException(); |
| | 238 | | } |
| | 239 | | catch (SocketException exception) |
| | 240 | | { |
| | 241 | | throw exception.ToIceRpcException(); |
| | 242 | | } |
| | 243 | | } |
| | 244 | | } |
| | 245 | |
|
| | 246 | | private protected TcpConnection(IMemoryOwner<byte>? memoryOwner) |
| | 247 | | { |
| | 248 | | _writeBufferOwner = memoryOwner; |
| | 249 | | // When coalescing leading buffers in WriteAsync (SSL only), the upper size limit is the lesser of the size of |
| | 250 | | // the buffer we rented from the memory pool (typically 4K) and MaxSslDataSize (16K). |
| | 251 | | _maxSslBufferSize = Math.Min(memoryOwner?.Memory.Length ?? 0, MaxSslDataSize); |
| | 252 | | } |
| | 253 | |
|
| | 254 | | private protected abstract Task<TransportConnectionInformation> ConnectAsyncCore( |
| | 255 | | CancellationToken cancellationToken); |
| | 256 | |
|
| | 257 | | /// <summary>Aborts the connection and then observes the exception of the provided task.</summary> |
| | 258 | | private async Task AbortAndObserveAsync(Task task) |
| | 259 | | { |
| | 260 | | Socket.Close(0); |
| | 261 | | try |
| | 262 | | { |
| | 263 | | await task.ConfigureAwait(false); |
| | 264 | | } |
| | 265 | | catch |
| | 266 | | { |
| | 267 | | // observe exception |
| | 268 | | } |
| | 269 | | } |
| | 270 | | } |
| | 271 | |
|
| | 272 | | internal class TcpClientConnection : TcpConnection |
| | 273 | | { |
| 1732 | 274 | | internal override Socket Socket { get; } |
| | 275 | |
|
| 534 | 276 | | internal override SslStream? SslStream => _sslStream; |
| | 277 | |
|
| | 278 | | private readonly EndPoint _addr; |
| | 279 | | private readonly SslClientAuthenticationOptions? _authenticationOptions; |
| | 280 | |
|
| | 281 | | private SslStream? _sslStream; |
| | 282 | |
|
| | 283 | | internal TcpClientConnection( |
| | 284 | | ServerAddress serverAddress, |
| | 285 | | SslClientAuthenticationOptions? authenticationOptions, |
| | 286 | | MemoryPool<byte> pool, |
| | 287 | | int minimumSegmentSize, |
| | 288 | | TcpClientTransportOptions options) |
| 264 | 289 | | : base(authenticationOptions is not null ? pool.Rent(minimumSegmentSize) : null) |
| 264 | 290 | | { |
| 264 | 291 | | _addr = IPAddress.TryParse(serverAddress.Host, out IPAddress? ipAddress) ? |
| 264 | 292 | | new IPEndPoint(ipAddress, serverAddress.Port) : |
| 264 | 293 | | new DnsEndPoint(serverAddress.Host, serverAddress.Port); |
| | 294 | |
|
| 264 | 295 | | _authenticationOptions = authenticationOptions; |
| | 296 | |
|
| | 297 | | // When using IPv6 address family we use the socket constructor without AddressFamily parameter to ensure |
| | 298 | | // dual-mode socket are used in platforms that support them. |
| 264 | 299 | | Socket = ipAddress?.AddressFamily == AddressFamily.InterNetwork ? |
| 264 | 300 | | new Socket(ipAddress.AddressFamily, SocketType.Stream, ProtocolType.Tcp) : |
| 264 | 301 | | new Socket(SocketType.Stream, ProtocolType.Tcp); |
| | 302 | |
|
| | 303 | | try |
| 264 | 304 | | { |
| 264 | 305 | | if (options.LocalNetworkAddress is IPEndPoint localNetworkAddress) |
| 2 | 306 | | { |
| 2 | 307 | | Socket.Bind(localNetworkAddress); |
| 2 | 308 | | } |
| | 309 | |
|
| 264 | 310 | | Socket.Configure(options); |
| 264 | 311 | | } |
| 0 | 312 | | catch (SocketException exception) |
| 0 | 313 | | { |
| 0 | 314 | | Socket.Dispose(); |
| 0 | 315 | | throw exception.ToIceRpcException(); |
| | 316 | | } |
| 0 | 317 | | catch |
| 0 | 318 | | { |
| 0 | 319 | | Socket.Dispose(); |
| 0 | 320 | | throw; |
| | 321 | | } |
| 264 | 322 | | } |
| | 323 | |
|
| | 324 | | private protected override async Task<TransportConnectionInformation> ConnectAsyncCore( |
| | 325 | | CancellationToken cancellationToken) |
| 236 | 326 | | { |
| | 327 | | try |
| 236 | 328 | | { |
| 236 | 329 | | Debug.Assert(Socket is not null); |
| | 330 | |
|
| | 331 | | // Connect to the peer. |
| 236 | 332 | | await Socket.ConnectAsync(_addr, cancellationToken).ConfigureAwait(false); |
| | 333 | |
|
| | 334 | | // Workaround: a canceled Socket.ConnectAsync call can return successfully but the Socket is closed because |
| | 335 | | // of the cancellation. See https://github.com/dotnet/runtime/issues/75889. |
| 219 | 336 | | cancellationToken.ThrowIfCancellationRequested(); |
| | 337 | |
|
| 219 | 338 | | if (_authenticationOptions is not null) |
| 54 | 339 | | { |
| 54 | 340 | | _sslStream = new SslStream(new NetworkStream(Socket, false), false); |
| | 341 | |
|
| 54 | 342 | | await _sslStream.AuthenticateAsClientAsync( |
| 54 | 343 | | _authenticationOptions, |
| 54 | 344 | | cancellationToken).ConfigureAwait(false); |
| 44 | 345 | | } |
| | 346 | |
|
| 209 | 347 | | return new TransportConnectionInformation( |
| 209 | 348 | | localNetworkAddress: Socket.LocalEndPoint!, |
| 209 | 349 | | remoteNetworkAddress: Socket.RemoteEndPoint!, |
| 209 | 350 | | _sslStream?.RemoteCertificate); |
| | 351 | | } |
| 2 | 352 | | catch (IOException exception) |
| 2 | 353 | | { |
| 2 | 354 | | throw exception.ToIceRpcException(); |
| | 355 | | } |
| 10 | 356 | | catch (SocketException exception) |
| 10 | 357 | | { |
| 10 | 358 | | throw exception.ToIceRpcException(); |
| | 359 | | } |
| 208 | 360 | | } |
| | 361 | | } |
| | 362 | |
|
| | 363 | | internal class TcpServerConnection : TcpConnection |
| | 364 | | { |
| | 365 | | internal override Socket Socket { get; } |
| | 366 | |
|
| | 367 | | internal override SslStream? SslStream => _sslStream; |
| | 368 | |
|
| | 369 | | private readonly SslServerAuthenticationOptions? _authenticationOptions; |
| | 370 | | private SslStream? _sslStream; |
| | 371 | |
|
| | 372 | | internal TcpServerConnection( |
| | 373 | | Socket socket, |
| | 374 | | SslServerAuthenticationOptions? authenticationOptions, |
| | 375 | | MemoryPool<byte> pool, |
| | 376 | | int minimumSegmentSize) |
| | 377 | | : base(authenticationOptions is not null ? pool.Rent(minimumSegmentSize) : null) |
| | 378 | | { |
| | 379 | | Socket = socket; |
| | 380 | | _authenticationOptions = authenticationOptions; |
| | 381 | | } |
| | 382 | |
|
| | 383 | | private protected override async Task<TransportConnectionInformation> ConnectAsyncCore( |
| | 384 | | CancellationToken cancellationToken) |
| | 385 | | { |
| | 386 | | try |
| | 387 | | { |
| | 388 | | if (_authenticationOptions is not null) |
| | 389 | | { |
| | 390 | | // This can only be created with a connected socket. |
| | 391 | | _sslStream = new SslStream(new NetworkStream(Socket, false), false); |
| | 392 | | await _sslStream.AuthenticateAsServerAsync( |
| | 393 | | _authenticationOptions, |
| | 394 | | cancellationToken).ConfigureAwait(false); |
| | 395 | | } |
| | 396 | |
|
| | 397 | | return new TransportConnectionInformation( |
| | 398 | | localNetworkAddress: Socket.LocalEndPoint!, |
| | 399 | | remoteNetworkAddress: Socket.RemoteEndPoint!, |
| | 400 | | _sslStream?.RemoteCertificate); |
| | 401 | | } |
| | 402 | | catch (IOException exception) |
| | 403 | | { |
| | 404 | | throw exception.ToIceRpcException(); |
| | 405 | | } |
| | 406 | | catch (SocketException exception) |
| | 407 | | { |
| | 408 | | throw exception.ToIceRpcException(); |
| | 409 | | } |
| | 410 | | } |
| | 411 | | } |