| | | 1 | | // Copyright (c) ZeroC, Inc. |
| | | 2 | | |
| | | 3 | | using Google.Protobuf; |
| | | 4 | | using System.Diagnostics; |
| | | 5 | | using System.IO.Pipelines; |
| | | 6 | | using System.Runtime.CompilerServices; |
| | | 7 | | |
| | | 8 | | namespace IceRpc.Protobuf.RpcMethods.Internal; |
| | | 9 | | |
| | | 10 | | /// <summary>The default <see cref="IAsyncStream{T}" /> implementation. It wraps a <see cref="PipeReader" /> and |
| | | 11 | | /// decodes its bytes into Protobuf messages of type <typeparamref name="T"/>.</summary> |
| | | 12 | | internal sealed class AsyncStream<T> : IAsyncStream<T> where T : class, IMessage<T> |
| | | 13 | | { |
| | | 14 | | private readonly PipeReader _reader; |
| | | 15 | | private readonly MessageParser<T> _messageParser; |
| | | 16 | | private readonly int _maxMessageLength; |
| | | 17 | | |
| | | 18 | | // Canceled by Dispose when iteration has started, to unblock any pending ReadAsync. |
| | 232 | 19 | | private readonly CancellationTokenSource _disposeCts = new(); |
| | | 20 | | |
| | | 21 | | // Set when GetAsyncEnumerator is called. This enforces the single-enumerator contract even if the created |
| | | 22 | | // enumerator is never advanced. |
| | | 23 | | private bool _enumeratorCreated; |
| | | 24 | | |
| | | 25 | | // Atomic state used to safely arbitrate ownership of _reader.Complete() between Dispose and the first |
| | | 26 | | // MoveNextAsync. |
| | | 27 | | private int _state; |
| | | 28 | | |
| | | 29 | | public void Dispose() |
| | 226 | 30 | | { |
| | 226 | 31 | | int original = Interlocked.Exchange(ref _state, (int)State.Disposed); |
| | | 32 | | |
| | 226 | 33 | | switch ((State)original) |
| | | 34 | | { |
| | | 35 | | case State.Initial: |
| | | 36 | | // No iteration could have started (and any future MoveNextAsync will see Disposed and throw). |
| | | 37 | | // Safe to complete the reader directly from this thread. |
| | 101 | 38 | | _reader.Complete(); |
| | 101 | 39 | | _disposeCts.Dispose(); |
| | 101 | 40 | | break; |
| | | 41 | | |
| | | 42 | | case State.Iterating: |
| | | 43 | | // The iterator owns the reader; its finally will complete it. We only signal cancellation here. |
| | | 44 | | // We must not dispose _disposeCts here: a linked CTS inside the iterator may still hold a |
| | | 45 | | // registration on _disposeCts.Token. |
| | 123 | 46 | | _disposeCts.Cancel(); |
| | 123 | 47 | | break; |
| | | 48 | | |
| | | 49 | | case State.Disposed: |
| | | 50 | | // no-op (Dispose called more than once). |
| | 2 | 51 | | break; |
| | | 52 | | } |
| | 226 | 53 | | } |
| | | 54 | | |
| | | 55 | | public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken) |
| | 231 | 56 | | { |
| | | 57 | | // We don't check for Disposed here: if the stream was disposed, the first MoveNextAsync call on the |
| | | 58 | | // returned enumerator throws ObjectDisposedException (see EnumerateAsync). |
| | 231 | 59 | | if (_enumeratorCreated) |
| | 1 | 60 | | { |
| | 1 | 61 | | throw new InvalidOperationException($"An {nameof(IAsyncStream<T>)} can only be enumerated once."); |
| | | 62 | | } |
| | 230 | 63 | | _enumeratorCreated = true; |
| | 230 | 64 | | return EnumerateAsync(cancellationToken).GetAsyncEnumerator(cancellationToken); |
| | 230 | 65 | | } |
| | | 66 | | |
| | 232 | 67 | | internal AsyncStream(PipeReader reader, MessageParser<T> messageParser, int maxMessageLength) |
| | 232 | 68 | | { |
| | 232 | 69 | | _reader = reader; |
| | 232 | 70 | | _messageParser = messageParser; |
| | 232 | 71 | | _maxMessageLength = maxMessageLength; |
| | 232 | 72 | | } |
| | | 73 | | |
| | | 74 | | private async IAsyncEnumerable<T> EnumerateAsync([EnumeratorCancellation] CancellationToken cancellationToken) |
| | 228 | 75 | | { |
| | | 76 | | // Because this async method returns an IAsyncEnumerable<T>, it only starts executing when the caller starts |
| | | 77 | | // iterating (calls MoveNextAsync on the enumerator). It does not execute when EnumerateAsync is called, or |
| | | 78 | | // even when GetAsyncEnumerator is called on the returned IAsyncEnumerable<T>. |
| | | 79 | | |
| | | 80 | | // Atomically claim the reader (Idle -> Iterating). This races with Dispose's atomic transition to Disposed; |
| | | 81 | | // whichever transition wins from Idle owns _reader.Complete(). |
| | 228 | 82 | | int original = Interlocked.CompareExchange(ref _state, (int)State.Iterating, (int)State.Initial); |
| | 228 | 83 | | ObjectDisposedException.ThrowIf(original == (int)State.Disposed, this); |
| | 131 | 84 | | Debug.Assert(original == (int)State.Initial); // _enumeratorCreated forbids a second iteration. |
| | | 85 | | |
| | | 86 | | // Link the caller-provided token with our internal dispose token so that Dispose can unblock ReadAsync. |
| | 131 | 87 | | using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource( |
| | 131 | 88 | | cancellationToken, |
| | 131 | 89 | | _disposeCts.Token); |
| | 131 | 90 | | CancellationToken linkedToken = linkedCts.Token; |
| | | 91 | | |
| | | 92 | | try |
| | 131 | 93 | | { |
| | 65914 | 94 | | while (true) |
| | 65914 | 95 | | { |
| | | 96 | | T? message; |
| | | 97 | | try |
| | 65914 | 98 | | { |
| | 65914 | 99 | | message = await _reader.ReadProtobufMessageAsync( |
| | 65914 | 100 | | _messageParser, |
| | 65914 | 101 | | _maxMessageLength, |
| | 65914 | 102 | | linkedToken).ConfigureAwait(false); |
| | 65805 | 103 | | } |
| | 108 | 104 | | catch (OperationCanceledException) when (linkedToken.IsCancellationRequested) |
| | 108 | 105 | | { |
| | | 106 | | // Re-issue the cancellation with the caller's token so the OCE that propagates carries the |
| | | 107 | | // token the caller passed in (not our internal linkedToken). When dispose is the only source, |
| | | 108 | | // surface dispose-mid-iteration as ObjectDisposedException. |
| | 108 | 109 | | cancellationToken.ThrowIfCancellationRequested(); |
| | | 110 | | |
| | | 111 | | // Safe to read _state without a barrier: Dispose writes State.Disposed before calling |
| | | 112 | | // _disposeCts.Cancel(), and observing the cancellation here establishes happens-before |
| | | 113 | | // with that write. |
| | 105 | 114 | | Debug.Assert(_state == (int)State.Disposed); |
| | 105 | 115 | | throw new ObjectDisposedException(nameof(AsyncStream<>), "The stream was disposed while reading."); |
| | | 116 | | } |
| | | 117 | | |
| | 65805 | 118 | | if (message is null) |
| | 20 | 119 | | { |
| | 20 | 120 | | yield break; |
| | | 121 | | } |
| | 65785 | 122 | | yield return message; |
| | 65783 | 123 | | } |
| | | 124 | | } |
| | | 125 | | finally |
| | 131 | 126 | | { |
| | 131 | 127 | | _reader.Complete(); |
| | 131 | 128 | | } |
| | 22 | 129 | | } |
| | | 130 | | |
| | | 131 | | private enum State |
| | | 132 | | { |
| | | 133 | | Initial = 0, |
| | | 134 | | Iterating = 1, |
| | | 135 | | Disposed = 2 |
| | | 136 | | } |
| | | 137 | | } |