Summary

Information
Class: IceRpc.Transports.Tcp.Internal.TcpServerConnection
Assembly: IceRpc
File(s): /home/runner/work/icerpc-csharp/icerpc-csharp/src/IceRpc/Transports/Tcp/Internal/TcpConnection.cs
Tag: 1321_24790053727
Line coverage
88%
Covered lines: 24
Uncovered lines: 3
Coverable lines: 27
Total lines: 436
Line coverage: 88.8%
Branch coverage
100%
Covered branches: 6
Total branches: 6
Branch coverage: 100%
Method coverage
100%
Covered methods: 4
Fully covered methods: 3
Total methods: 4
Method coverage: 100%
Full method coverage: 75%

Metrics

Method | Branch coverage | Crap Score | Cyclomatic complexity | Line coverage
get_Socket() | 100% | 1 | 1 | 100%
get_SslStream() | 100% | 1 | 1 | 100%
.ctor(...) | 100% | 2 | 2 | 100%
ConnectAsyncCore() | 100% | 4 | 4 | 85%

File(s)

/home/runner/work/icerpc-csharp/icerpc-csharp/src/IceRpc/Transports/Tcp/Internal/TcpConnection.cs

# (hit count) | Line | Line coverage
 1// Copyright (c) ZeroC, Inc.
 2
 3using System.Buffers;
 4using System.Diagnostics;
 5using System.Net;
 6using System.Net.Security;
 7using System.Net.Sockets;
 8using System.Runtime.InteropServices;
 9
 10namespace IceRpc.Transports.Tcp.Internal;
 11
 12/// <summary>Implements <see cref="IDuplexConnection" /> for tcp with or without TLS.</summary>
 13/// <remarks>Unlike Coloc, the Tcp transport is not a "checked" transport, which means it does not need to detect
 14/// violations of the duplex transport contract or report such violations. It assumes its clients are sufficiently well
 15/// tested to never violate this contract. As a result, this implementation does not throw
 16/// <see cref="InvalidOperationException" />.</remarks>
 17internal abstract class TcpConnection : IDuplexConnection
 18{
 19    internal abstract Socket Socket { get; }
 20
 21    internal abstract SslStream? SslStream { get; }
 22
 23    private protected volatile bool _isDisposed;
 24
 25    // The MaxDataSize of the SSL implementation.
 26    private const int MaxSslDataSize = 16 * 1024;
 27
 28    private bool _isShutdown;
 29    private readonly int _maxSslBufferSize;
 30    private readonly List<ArraySegment<byte>> _segments = new();
 31    private readonly IMemoryOwner<byte>? _writeBufferOwner;
 32
 33    public Task<TransportConnectionInformation> ConnectAsync(CancellationToken cancellationToken)
 34    {
 35        ObjectDisposedException.ThrowIf(_isDisposed, this);
 36        return ConnectAsyncCore(cancellationToken);
 37    }
 38
 39    public void Dispose()
 40    {
 41        _isDisposed = true;
 42
 43        if (SslStream is SslStream sslStream)
 44        {
 45            sslStream.Dispose();
 46        }
 47
 48        // If shutdown was called, we can just dispose the connection to complete the graceful TCP closure. Otherwise,
 49        // we abort the TCP connection to ensure the connection doesn't end up in the TIME_WAIT state.
 50        if (_isShutdown)
 51        {
 52            Socket.Dispose();
 53        }
 54        else
 55        {
 56            Socket.Close(0);
 57        }
 58        _writeBufferOwner?.Dispose();
 59    }
 60
 61    public ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken)
 62    {
 63        ObjectDisposedException.ThrowIf(_isDisposed, this);
 64
 65        return buffer.Length > 0 ? PerformReadAsync() :
 66            throw new ArgumentException($"The {nameof(buffer)} cannot be empty.", nameof(buffer));
 67
 68        async ValueTask<int> PerformReadAsync()
 69        {
 70            try
 71            {
 72                return SslStream is SslStream sslStream ?
 73                    await SslStream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false) :
 74                    await Socket.ReceiveAsync(buffer, SocketFlags.None, cancellationToken).ConfigureAwait(false);
 75            }
 76            catch (IOException exception)
 77            {
 78                throw exception.ToIceRpcException();
 79            }
 80            catch (SocketException exception)
 81            {
 82                throw exception.ToIceRpcException();
 83            }
 84        }
 85    }
 86
 87    public Task ShutdownWriteAsync(CancellationToken cancellationToken)
 88    {
 89        ObjectDisposedException.ThrowIf(_isDisposed, this);
 90
 91        return PerformShutdownAsync();
 92
 93        async Task PerformShutdownAsync()
 94        {
 95            try
 96            {
 97                if (SslStream is SslStream sslStream)
 98                {
 99                    Task shutdownTask = sslStream.ShutdownAsync();
 100
 101                    try
 102                    {
 103                        await shutdownTask.WaitAsync(cancellationToken).ConfigureAwait(false);
 104                    }
 105                    catch (OperationCanceledException)
 106                    {
 107                        await AbortAndObserveAsync(shutdownTask).ConfigureAwait(false);
 108                        throw;
 109                    }
 110                }
 111
 112                // Shutdown the socket send side to send a TCP FIN packet. We don't close the read side because we want
 113                // to be notified when the peer shuts down it's side of the socket (through the ReceiveAsync call).
 114                Socket.Shutdown(SocketShutdown.Send);
 115
 116                // If shutdown is successful mark the connection as shutdown to ensure Dispose won't reset the TCP
 117                // connection.
 118                _isShutdown = true;
 119            }
 120            catch (IOException exception)
 121            {
 122                throw exception.ToIceRpcException();
 123            }
 124            catch (SocketException exception)
 125            {
 126                throw exception.ToIceRpcException();
 127            }
 128        }
 129    }
 130
 131    public ValueTask WriteAsync(ReadOnlySequence<byte> buffer, CancellationToken cancellationToken)
 132    {
 133        ObjectDisposedException.ThrowIf(_isDisposed, this);
 134        return PerformWriteAsync();
 135
 136        async ValueTask PerformWriteAsync()
 137        {
 138            try
 139            {
 140                if (SslStream is SslStream sslStream)
 141                {
 142                    if (buffer.IsSingleSegment)
 143                    {
 144                        await sslStream.WriteAsync(buffer.First, cancellationToken).ConfigureAwait(false);
 145                    }
 146                    else
 147                    {
 148                        // Coalesce leading segments up to _maxSslBufferSize. We don't coalesce trailing segments as we
 149                        // assume these segments are large enough.
 150                        int leadingSize = 0;
 151                        int leadingSegmentCount = 0;
 152                        foreach (ReadOnlyMemory<byte> memory in buffer)
 153                        {
 154                            if (leadingSize + memory.Length <= _maxSslBufferSize)
 155                            {
 156                                leadingSize += memory.Length;
 157                                leadingSegmentCount++;
 158                            }
 159                            else
 160                            {
 161                                break;
 162                            }
 163                        }
 164
 165                        if (leadingSegmentCount > 1)
 166                        {
 167                            ReadOnlySequence<byte> leading = buffer.Slice(0, leadingSize);
 168                            buffer = buffer.Slice(leadingSize); // buffer can become empty
 169
 170                            Debug.Assert(_writeBufferOwner is not null);
 171                            Memory<byte> writeBuffer = _writeBufferOwner.Memory[0..leadingSize];
 172                            leading.CopyTo(writeBuffer.Span);
 173
 174                            // Send the "coalesced" leading segments
 175                            await sslStream.WriteAsync(writeBuffer, cancellationToken).ConfigureAwait(false);
 176                        }
 177                        // else no need to coalesce (copy) a single segment
 178
 179                        // Send the remaining segments one by one
 180                        if (buffer.IsEmpty)
 181                        {
 182                            // done
 183                        }
 184                        else if (buffer.IsSingleSegment)
 185                        {
 186                            await sslStream.WriteAsync(buffer.First, cancellationToken).ConfigureAwait(false);
 187                        }
 188                        else
 189                        {
 190                            foreach (ReadOnlyMemory<byte> memory in buffer)
 191                            {
 192                                await sslStream.WriteAsync(memory, cancellationToken).ConfigureAwait(false);
 193                            }
 194                        }
 195                    }
 196                }
 197                else
 198                {
 199                    if (buffer.IsSingleSegment)
 200                    {
 201                        int bytesSent = await Socket.SendAsync(
 202                            buffer.First,
 203                            SocketFlags.None,
 204                            cancellationToken).ConfigureAwait(false);
 205
 206                        if (bytesSent != buffer.First.Length)
 207                        {
 208                            // This should never happen.
 209                            throw new IceRpcException(
 210                                IceRpcError.IceRpcError,
 211                                $"Short write on TCP socket: expected {buffer.First.Length} bytes but sent {bytesSent}."
 212                        }
 213                    }
 214                    else
 215                    {
 216                        _segments.Clear();
 217                        long totalBytes = buffer.Length;
 218                        foreach (ReadOnlyMemory<byte> memory in buffer)
 219                        {
 220                            if (MemoryMarshal.TryGetArray(memory, out ArraySegment<byte> segment))
 221                            {
 222                                _segments.Add(segment);
 223                            }
 224                            else
 225                            {
 226                                throw new ArgumentException(
 227                                    $"The {nameof(buffer)} must be backed by arrays.",
 228                                    nameof(buffer));
 229                            }
 230                        }
 231
 232                        Task<int> sendTask = Socket.SendAsync(_segments, SocketFlags.None);
 233
 234                        int bytesSent;
 235                        try
 236                        {
 237                            bytesSent = await sendTask.WaitAsync(cancellationToken).ConfigureAwait(false);
 238                        }
 239                        catch (OperationCanceledException)
 240                        {
 241                            await AbortAndObserveAsync(sendTask).ConfigureAwait(false);
 242                            throw;
 243                        }
 244
 245                        if (bytesSent != totalBytes)
 246                        {
 247                            // This should never happen.
 248                            throw new IceRpcException(
 249                                IceRpcError.IceRpcError,
 250                                $"Short write on TCP socket: expected {totalBytes} bytes but sent {bytesSent}.");
 251                        }
 252                    }
 253                }
 254            }
 255            catch (IOException exception)
 256            {
 257                throw exception.ToIceRpcException();
 258            }
 259            catch (SocketException exception)
 260            {
 261                throw exception.ToIceRpcException();
 262            }
 263        }
 264    }
 265
 266    private protected TcpConnection(IMemoryOwner<byte>? memoryOwner)
 267    {
 268        _writeBufferOwner = memoryOwner;
 269        // When coalescing leading buffers in WriteAsync (SSL only), the upper size limit is the lesser of the size of
 270        // the buffer we rented from the memory pool (typically 4K) and MaxSslDataSize (16K).
 271        _maxSslBufferSize = Math.Min(memoryOwner?.Memory.Length ?? 0, MaxSslDataSize);
 272    }
 273
 274    private protected abstract Task<TransportConnectionInformation> ConnectAsyncCore(
 275        CancellationToken cancellationToken);
 276
 277    /// <summary>Aborts the connection and then observes the exception of the provided task.</summary>
 278    private async Task AbortAndObserveAsync(Task task)
 279    {
 280        Socket.Close(0);
 281        try
 282        {
 283            await task.ConfigureAwait(false);
 284        }
 285        catch
 286        {
 287            // observe exception
 288        }
 289    }
 290}
 291
 292internal class TcpClientConnection : TcpConnection
 293{
 294    internal override Socket Socket { get; }
 295
 296    internal override SslStream? SslStream => _sslStream;
 297
 298    private readonly EndPoint _address;
 299    private readonly SslClientAuthenticationOptions? _authenticationOptions;
 300
 301    private SslStream? _sslStream;
 302
 303    internal TcpClientConnection(
 304        TransportAddress transportAddress,
 305        SslClientAuthenticationOptions? authenticationOptions,
 306        MemoryPool<byte> pool,
 307        int minimumSegmentSize,
 308        TcpClientTransportOptions options)
 309        : base(authenticationOptions is not null ? pool.Rent(minimumSegmentSize) : null)
 310    {
 311        _address = IPAddress.TryParse(transportAddress.Host, out IPAddress? ipAddress) ?
 312            new IPEndPoint(ipAddress, transportAddress.Port) :
 313            new DnsEndPoint(transportAddress.Host, transportAddress.Port);
 314
 315        _authenticationOptions = authenticationOptions;
 316
 317        // When using IPv6 address family we use the socket constructor without AddressFamily parameter to ensure
 318        // dual-mode socket are used in platforms that support them.
 319        Socket = ipAddress?.AddressFamily == AddressFamily.InterNetwork ?
 320            new Socket(ipAddress.AddressFamily, SocketType.Stream, ProtocolType.Tcp) :
 321            new Socket(SocketType.Stream, ProtocolType.Tcp);
 322
 323        try
 324        {
 325            if (options.LocalNetworkAddress is IPEndPoint localNetworkAddress)
 326            {
 327                Socket.Bind(localNetworkAddress);
 328            }
 329
 330            Socket.Configure(options);
 331        }
 332        catch (SocketException exception)
 333        {
 334            Socket.Dispose();
 335            throw exception.ToIceRpcException();
 336        }
 337        catch
 338        {
 339            Socket.Dispose();
 340            throw;
 341        }
 342    }
 343
 344    private protected override async Task<TransportConnectionInformation> ConnectAsyncCore(
 345        CancellationToken cancellationToken)
 346    {
 347        bool isConnected = false;
 348        try
 349        {
 350            Debug.Assert(Socket is not null);
 351
 352            // Connect to the peer.
 353            await Socket.ConnectAsync(_address, cancellationToken).ConfigureAwait(false);
 354            isConnected = true;
 355
 356            if (_authenticationOptions is not null)
 357            {
 358                _sslStream = new SslStream(new NetworkStream(Socket, false), false);
 359
 360                await _sslStream.AuthenticateAsClientAsync(
 361                    _authenticationOptions,
 362                    cancellationToken).ConfigureAwait(false);
 363            }
 364
 365            return new TransportConnectionInformation(
 366                localNetworkAddress: Socket.LocalEndPoint!,
 367                remoteNetworkAddress: Socket.RemoteEndPoint!,
 368                _sslStream?.RemoteCertificate);
 369        }
 370        catch (IOException exception)
 371        {
 372            throw exception.ToIceRpcException();
 373        }
 374        catch (SocketException exception) when (isConnected)
 375        {
 376            // This can happen if the peer closes the connection immediately after accepting it, which can cause the
 377            // endpoint information to be unavailable. Any SocketException at this point means the connection is no
 378            // longer usable.
 379            throw new IceRpcException(IceRpcError.ConnectionAborted, exception);
 380        }
 381        catch (SocketException exception)
 382        {
 383            throw exception.ToIceRpcException();
 384        }
 385    }
 386}
 387
 388internal class TcpServerConnection : TcpConnection
 389{
 6562390    internal override Socket Socket { get; }
 391
 12525392    internal override SslStream? SslStream => _sslStream;
 393
 394    private readonly SslServerAuthenticationOptions? _authenticationOptions;
 395    private SslStream? _sslStream;
 396
 397    internal TcpServerConnection(
 398        Socket socket,
 399        SslServerAuthenticationOptions? authenticationOptions,
 400        MemoryPool<byte> pool,
 401        int minimumSegmentSize)
 92402        : base(authenticationOptions is not null ? pool.Rent(minimumSegmentSize) : null)
 92403    {
 92404        Socket = socket;
 92405        _authenticationOptions = authenticationOptions;
 92406    }
 407
 408    private protected override async Task<TransportConnectionInformation> ConnectAsyncCore(
 409        CancellationToken cancellationToken)
 88410    {
 411        try
 88412        {
 88413            if (_authenticationOptions is not null)
 29414            {
 415                // This can only be created with a connected socket.
 29416                _sslStream = new SslStream(new NetworkStream(Socket, false), false);
 29417                await _sslStream.AuthenticateAsServerAsync(
 29418                    _authenticationOptions,
 29419                    cancellationToken).ConfigureAwait(false);
 24420            }
 421
 83422            return new TransportConnectionInformation(
 83423                localNetworkAddress: Socket.LocalEndPoint!,
 83424                remoteNetworkAddress: Socket.RemoteEndPoint!,
 83425                _sslStream?.RemoteCertificate);
 426        }
 2427        catch (IOException exception)
 2428        {
 2429            throw exception.ToIceRpcException();
 430        }
 0431        catch (SocketException exception)
 0432        {
 0433            throw exception.ToIceRpcException();
 434        }
 83435    }
 436}