| | 1 | | // Copyright (c) ZeroC, Inc. |
| | 2 | |
|
| | 3 | | using Google.Protobuf; |
| | 4 | | using System.Buffers; |
| | 5 | | using System.Buffers.Binary; |
| | 6 | | using System.Diagnostics; |
| | 7 | | using System.IO.Pipelines; |
| | 8 | | using System.Runtime.CompilerServices; |
| | 9 | |
|
| | 10 | | namespace IceRpc.Protobuf.Internal; |
| | 11 | |
|
/// <summary>Provides extension methods for <see cref="PipeReader" />.</summary>
internal static class PipeReaderExtensions
{
    /// <summary>The length of the header that precedes each message: a 1-byte compression flag followed by a
    /// 4-byte big-endian message length.</summary>
    private const int HeaderLength = 5;

    /// <summary>Decodes a Protobuf length prefixed message from a <see cref="PipeReader" />.</summary>
    /// <typeparam name="T">The type of the Protobuf message.</typeparam>
    /// <param name="reader">The <see cref="PipeReader" /> containing the Protobuf length prefixed message.</param>
    /// <param name="parser">The <see cref="MessageParser{T}" /> used to parse the message data.</param>
    /// <param name="maxMessageLength">The maximum allowed length, in bytes, of the message data (the 5-byte
    /// header is not counted).</param>
    /// <param name="cancellationToken">A cancellation token that receives the cancellation requests.</param>
    /// <returns>The decoded message object.</returns>
    /// <exception cref="InvalidDataException">Thrown when the reader holds no data, a truncated header or message,
    /// an invalid compression flag, or a message longer than <paramref name="maxMessageLength" />.</exception>
    /// <exception cref="NotSupportedException">Thrown when the message is flagged as compressed.</exception>
    internal static async ValueTask<T> DecodeProtobufMessageAsync<T>(
        this PipeReader reader,
        MessageParser<T> parser,
        int maxMessageLength,
        CancellationToken cancellationToken) where T : IMessage<T>
    {
        T? message = await ReadMessageAsync(
            reader,
            parser,
            maxMessageLength,
            cancellationToken).ConfigureAwait(false);

        // ReadMessageAsync returns null only when the reader completed without any data. A single-message payload
        // always carries at least the 5-byte header (even for an empty message), so this is malformed input.
        if (message is null)
        {
            throw new InvalidDataException("The payload has 0 bytes, but 5 bytes were expected.");
        }
        return message;
    }

    /// <summary>Decodes an async enumerable from a pipe reader.</summary>
    /// <typeparam name="T">The type of the Protobuf messages.</typeparam>
    /// <param name="reader">The pipe reader. It is completed when the iteration ends, whether normally, through
    /// cancellation, or because of an exception.</param>
    /// <param name="messageParser">The <see cref="MessageParser{T}" /> used to parse the message data.</param>
    /// <param name="maxMessageLength">The maximum allowed length, in bytes, of each message's data.</param>
    /// <param name="cancellationToken">The cancellation token which is provided to <see
    /// cref="IAsyncEnumerable{T}.GetAsyncEnumerator(CancellationToken)" />. Canceling it ends the iteration
    /// without an exception.</param>
    /// <returns>The messages read from <paramref name="reader" />, in order, until the reader has no more data.
    /// </returns>
    internal static async IAsyncEnumerable<T> ToAsyncEnumerable<T>(
        this PipeReader reader,
        MessageParser<T> messageParser,
        int maxMessageLength,
        [EnumeratorCancellation] CancellationToken cancellationToken = default) where T : IMessage<T>
    {
        try
        {
            while (true)
            {
                if (cancellationToken.IsCancellationRequested)
                {
                    yield break;
                }

                T? message;
                try
                {
                    message = await ReadMessageAsync(
                        reader,
                        messageParser,
                        maxMessageLength,
                        cancellationToken).ConfigureAwait(false);
                }
                catch (OperationCanceledException exception) when (exception.CancellationToken == cancellationToken)
                {
                    // Canceling the cancellation token is a normal way to complete an iteration.
                    yield break;
                }

                // A null message means the reader completed with no remaining data: the end of the stream.
                if (message is null)
                {
                    yield break;
                }
                yield return message;
            }
        }
        finally
        {
            reader.Complete();
        }
    }

    /// <summary>Reads one length-prefixed Protobuf message (5-byte header followed by the message data) from the
    /// reader and parses it.</summary>
    /// <typeparam name="T">The type of the Protobuf message.</typeparam>
    /// <param name="reader">The pipe reader.</param>
    /// <param name="messageParser">The parser used to parse the message data.</param>
    /// <param name="maxMessageLength">The maximum allowed length, in bytes, of the message data.</param>
    /// <param name="cancellationToken">A cancellation token that receives the cancellation requests.</param>
    /// <returns>The parsed message, or <c>default</c> (null) when the reader completed without any remaining
    /// data.</returns>
    private static async ValueTask<T?> ReadMessageAsync<T>(
        PipeReader reader,
        MessageParser<T> messageParser,
        int maxMessageLength,
        CancellationToken cancellationToken) where T : IMessage<T>
    {
        ReadResult readResult = await reader.ReadAtLeastAsync(HeaderLength, cancellationToken).ConfigureAwait(false);
        // We never call CancelPendingRead; an interceptor or middleware can but it's not correct.
        if (readResult.IsCanceled)
        {
            throw new InvalidOperationException("Unexpected call to CancelPendingRead.");
        }

        if (readResult.Buffer.IsEmpty)
        {
            // ReadAtLeastAsync returns fewer bytes than requested only when the read is completed (or canceled,
            // handled above), so an empty buffer is the clean end of the data.
            reader.AdvanceTo(readResult.Buffer.End);
            return default;
        }

        if (readResult.Buffer.Length < HeaderLength)
        {
            throw new InvalidDataException(
                $"The payload has {readResult.Buffer.Length} bytes, but 5 bytes were expected.");
        }

        (byte compressionFlag, int messageLength) = DecodeHeader(readResult.Buffer.Slice(0, HeaderLength));
        reader.AdvanceTo(readResult.Buffer.GetPosition(HeaderLength));

        if (compressionFlag == 1)
        {
            throw new NotSupportedException("IceRPC does not support Protobuf compressed messages");
        }
        if (compressionFlag != 0)
        {
            // The length-prefixed message format only defines 0 (uncompressed) and 1 (compressed).
            throw new InvalidDataException($"The compression flag '{compressionFlag}' is not valid.");
        }

        // The length is an unsigned 32-bit value on the wire; lengths above int.MaxValue decode as negative and are
        // rejected as too large. A length equal to maxMessageLength is allowed.
        if (messageLength < 0 || messageLength > maxMessageLength)
        {
            throw new InvalidDataException("The message length exceeds the maximum value.");
        }

        if (messageLength == 0)
        {
            // Nothing to read: calling ReadAtLeastAsync(0) could wait for bytes that belong to the next message.
            return messageParser.ParseFrom(ReadOnlySequence<byte>.Empty);
        }

        readResult = await reader.ReadAtLeastAsync(messageLength, cancellationToken).ConfigureAwait(false);
        // We never call CancelPendingRead; an interceptor or middleware can but it's not correct.
        if (readResult.IsCanceled)
        {
            throw new InvalidOperationException("Unexpected call to CancelPendingRead.");
        }

        if (readResult.Buffer.Length < messageLength)
        {
            throw new InvalidDataException(
                $"The payload has {readResult.Buffer.Length} bytes, but {messageLength} bytes were expected.");
        }

        // Hand the parser exactly messageLength bytes so that any following bytes (the next message) stay in the
        // reader. NOTE(review): ParseFrom is expected to parse the whole slice it is given.
        T message = messageParser.ParseFrom(readResult.Buffer.Slice(0, messageLength));
        reader.AdvanceTo(readResult.Buffer.GetPosition(messageLength));
        return message;

        // Copies the header out of the (possibly multi-segment) buffer and splits it into the 1-byte compression
        // flag and the 4-byte big-endian message length.
        static (byte CompressionFlag, int MessageLength) DecodeHeader(ReadOnlySequence<byte> buffer)
        {
            Debug.Assert(buffer.Length == HeaderLength);
            Span<byte> header = stackalloc byte[HeaderLength];
            buffer.CopyTo(header);
            return (header[0], BinaryPrimitives.ReadInt32BigEndian(header[1..]));
        }
    }
}