aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIvan Lozano <ivanlozano@google.com>2023-10-19 00:28:46 +0000
committerAutomerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>2023-10-19 00:28:46 +0000
commit41a562b84f095d95cffd41b24a0037bc3122fac5 (patch)
treee7e603c24ee7d164050a19a1681ceb72d551e651
parent1e6785177a4326f8b0e2eabebdb0e9ea369c620d (diff)
parent0a41a4b2061c07700ff50eb2d9164e1637695d94 (diff)
downloadgrpcio-41a562b84f095d95cffd41b24a0037bc3122fac5.tar.gz
Update grpcio crate to 0.12.1 am: e55ac28d76 am: 0a41a4b206
Original change: https://android-review.googlesource.com/c/platform/external/rust/crates/grpcio/+/2663016 Change-Id: I3cc4cae42d3837134a3214fd501fd3935ecf3417 Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
-rw-r--r--Android.bp7
-rw-r--r--CHANGELOG.md46
-rw-r--r--Cargo.toml71
-rw-r--r--Cargo.toml.orig24
-rw-r--r--METADATA8
-rw-r--r--README.md11
-rw-r--r--cargo2android.json2
-rw-r--r--src/buf.rs32
-rw-r--r--src/call/client.rs244
-rw-r--r--src/call/mod.rs148
-rw-r--r--src/call/server.rs47
-rw-r--r--src/channel.rs155
-rw-r--r--src/channelz.rs133
-rw-r--r--src/client.rs15
-rw-r--r--src/codec.rs4
-rw-r--r--src/env.rs2
-rw-r--r--src/error.rs5
-rw-r--r--src/lib.rs13
-rw-r--r--src/log_util.rs2
-rw-r--r--src/metadata.rs57
-rw-r--r--src/quota.rs9
-rw-r--r--src/security/credentials.rs78
-rw-r--r--src/security/mod.rs77
-rw-r--r--src/server.rs265
-rw-r--r--src/task/executor.rs5
-rw-r--r--src/task/mod.rs13
-rw-r--r--src/task/promise.rs52
27 files changed, 1026 insertions, 499 deletions
diff --git a/Android.bp b/Android.bp
index 66a47a1..a8a507d 100644
--- a/Android.bp
+++ b/Android.bp
@@ -24,15 +24,18 @@ rust_library {
host_supported: true,
crate_name: "grpcio",
cargo_env_compat: true,
- cargo_pkg_version: "0.9.1",
+ cargo_pkg_version: "0.12.1",
srcs: ["src/lib.rs"],
edition: "2018",
features: [
+ "_secure",
+ "boringssl",
"protobuf",
"protobuf-codec",
],
rustlibs: [
- "libfutures",
+ "libfutures_util",
+ "libfutures_executor",
"libgrpcio_sys",
"liblibc",
"liblog_rust",
diff --git a/CHANGELOG.md b/CHANGELOG.md
index aa28773..2899f77 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,49 @@
+# 0.12.1 - 2023-02-14
+
+- Provide more debug info for RpcStatus (#603)
+- Compile on latest stable (#605)
+- Allow accessing grpcio client and channel (#597)
+
+# 0.12.0 - 2022-11-04
+
+- Update prost to 0.11 (#595)
+- Update grpc to 1.36.5 (#593)
+- Update the security API of channel to match the C++ version (#593)
+
+Note, 1.36.5 removes the support of epollex engine (which is the default engine before),
+and enables transparent retry by default. So you may experience potential performance regression.
+And 1.36.5 is also the last version that supports C++11. Next version will requires C++14.
+
+# 0.11.0 - 2022-09-10
+
+- Update prost to 0.10 (#582)
+
+# 0.10.3 - 2022-06-27
+
+- Add support for GRPC_ARG_ENABLE_HTTP_PROXY parameter (#575)
+- Support setting gzip level (#577)
+
+# 0.10.2 - 2022-04-15
+
+- Make `ResourceQuota` cloneable (#568)
+- Allow use local subchannel pool (#565)
+
+# 0.10.1 - 2022-03-28
+
+- Fix potential UAF and double free (#566)
+
+# 0.10.0 - 2022-03-02
+
+- Update prost to 0.9.0 (#544) (#559)
+- Make `CallOption` sync (#551)
+- Update grpc c core to 1.44.0 (#549) (#558)
+- Support querying channelz by API (#550)
+- Reduce dependency on future crate (#554)
+- Support headers on all call types (#555)
+- Rename features "secure" to "boringssl" (#558)
+- Drop dependency on bindgen for both MacOS and x86_64/aarch64 Linux (#558)
+- Make health crate not depend on secure feature (#558)
+
# 0.9.1 - 2021-09-18
- Make boringssl-src optional (#537)
diff --git a/Cargo.toml b/Cargo.toml
index fd9e71f..83a57dd 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -3,38 +3,56 @@
# When uploading crates to the registry Cargo will automatically
# "normalize" Cargo.toml files for maximal compatibility
# with all versions of Cargo and also rewrite `path` dependencies
-# to registry (e.g., crates.io) dependencies
+# to registry (e.g., crates.io) dependencies.
#
-# If you believe there's an error in this file please file an
-# issue against the rust-lang/cargo repository. If you're
-# editing this file be aware that the upstream Cargo.toml
-# will likely look very different (and much more reasonable)
+# If you are reading this file be aware that the original Cargo.toml
+# will likely look very different (and much more reasonable).
+# See Cargo.toml.orig for the original contents.
[package]
edition = "2018"
name = "grpcio"
-version = "0.9.1"
+version = "0.12.1"
authors = ["The TiKV Project Developers"]
autoexamples = false
description = "The rust language implementation of gRPC, base on the gRPC c core library."
homepage = "https://github.com/tikv/grpc-rs"
documentation = "https://docs.rs/grpcio"
readme = "README.md"
-keywords = ["grpc", "protobuf", "rpc", "tls", "http2"]
-categories = ["asynchronous", "network-programming"]
+keywords = [
+ "grpc",
+ "protobuf",
+ "rpc",
+ "tls",
+ "http2",
+]
+categories = [
+ "asynchronous",
+ "network-programming",
+]
license = "Apache-2.0"
repository = "https://github.com/tikv/grpc-rs"
+
[package.metadata.docs.rs]
all-features = true
+
[dependencies.bytes]
version = "1.0"
optional = true
-[dependencies.futures]
+[dependencies.futures-executor]
+version = "0.3"
+
+[dependencies.futures-util]
version = "0.3"
+features = [
+ "std",
+ "sink",
+]
+default-features = false
[dependencies.grpcio-sys]
-version = "0.9"
+version = "0.12.1"
default-features = false
[dependencies.libc]
@@ -44,10 +62,10 @@ version = "0.2"
version = "0.4"
[dependencies.parking_lot]
-version = "0.11"
+version = "0.12"
[dependencies.prost]
-version = "0.7"
+version = "0.11"
optional = true
[dependencies.protobuf]
@@ -55,13 +73,30 @@ version = "2.0"
optional = true
[features]
-default = ["protobuf-codec", "secure", "use-bindgen"]
+_secure = []
+boringssl = [
+ "grpcio-sys/boringssl",
+ "_secure",
+]
+default = [
+ "protobuf-codec",
+ "boringssl",
+]
+nightly = []
no-omit-frame-pointer = ["grpcio-sys/no-omit-frame-pointer"]
-openssl = ["secure", "grpcio-sys/openssl"]
-openssl-vendored = ["secure", "grpcio-sys/openssl-vendored"]
-prost-codec = ["prost", "bytes"]
+openssl = [
+ "_secure",
+ "grpcio-sys/openssl",
+]
+openssl-vendored = [
+ "_secure",
+ "grpcio-sys/openssl-vendored",
+]
+prost-codec = [
+ "prost",
+ "bytes",
+]
protobuf-codec = ["protobuf"]
-secure = ["grpcio-sys/secure"]
-use-bindgen = ["grpcio-sys/use-bindgen"]
+
[badges.travis-ci]
repository = "tikv/grpc-rs"
diff --git a/Cargo.toml.orig b/Cargo.toml.orig
index 9755697..68067cf 100644
--- a/Cargo.toml.orig
+++ b/Cargo.toml.orig
@@ -1,6 +1,6 @@
[package]
name = "grpcio"
-version = "0.9.1"
+version = "0.12.1"
edition = "2018"
authors = ["The TiKV Project Developers"]
license = "Apache-2.0"
@@ -17,14 +17,15 @@ autoexamples = false
all-features = true
[dependencies]
-grpcio-sys = { path = "grpc-sys", version = "0.9", default-features = false }
+grpcio-sys = { path = "grpc-sys", version = "0.12.1", default-features = false }
libc = "0.2"
-futures = "0.3"
+futures-executor = "0.3"
+futures-util = { version = "0.3", default-features = false, features = ["std", "sink"] }
protobuf = { version = "2.0", optional = true }
-prost = { version = "0.7", optional = true }
+prost = { version = "0.11", optional = true }
bytes = { version = "1.0", optional = true }
log = "0.4"
-parking_lot = "0.11"
+parking_lot = "0.12"
[workspace]
members = [
@@ -40,17 +41,18 @@ members = [
exclude = ["xtask"]
[features]
-default = ["protobuf-codec", "secure", "use-bindgen"]
+default = ["protobuf-codec", "boringssl"]
+_secure = []
protobuf-codec = ["protobuf"]
prost-codec = ["prost", "bytes"]
-secure = ["grpcio-sys/secure"]
-openssl = ["secure", "grpcio-sys/openssl"]
-openssl-vendored = ["secure", "grpcio-sys/openssl-vendored"]
+nightly = []
+boringssl = ["grpcio-sys/boringssl", "_secure"]
+openssl = ["_secure", "grpcio-sys/openssl"]
+openssl-vendored = ["_secure", "grpcio-sys/openssl-vendored"]
no-omit-frame-pointer = ["grpcio-sys/no-omit-frame-pointer"]
-use-bindgen = ["grpcio-sys/use-bindgen"]
[badges]
travis-ci = { repository = "tikv/grpc-rs" }
[patch.crates-io]
-grpcio-compiler = { path = "compiler", version = "0.9.0", default-features = false }
+grpcio-compiler = { path = "compiler", version = "0.12.1", default-features = false }
diff --git a/METADATA b/METADATA
index 8384cab..3129a8f 100644
--- a/METADATA
+++ b/METADATA
@@ -9,11 +9,11 @@ third_party {
type: ARCHIVE
value: "https://static.crates.io/crates/grpcio/grpcio-0.9.1.crate"
}
- version: "0.9.1"
+ version: "0.12.1"
license_type: NOTICE
last_upgrade_date {
- year: 2021
- month: 9
- day: 22
+ year: 2023
+ month: 6
+ day: 21
}
}
diff --git a/README.md b/README.md
index ca9e9da..7767c62 100644
--- a/README.md
+++ b/README.md
@@ -29,7 +29,6 @@ This project is still under development. The following features with the check m
- Rust >= 1.36.0
- binutils >= 2.22
- LLVM and Clang >= 3.9 if you need to generate bindings at compile time.
-- By default, the [secure feature](#feature-secure) is provided by boringssl. You can also use openssl instead by enabling [openssl feature](#feature-openssl).
For Linux and MacOS, you also need to install gcc 4.9+ (or clang) too.
@@ -90,17 +89,17 @@ To include this project as a dependency:
```
[dependencies]
-grpcio = "0.6"
+grpcio = "0.12"
```
-### Feature `secure`
+### Feature `boringssl`
-`secure` feature enables support for TLS encryption and some authentication
+`boringssl` feature enables support for TLS encryption and some authentication
mechanism. When you do not need it, for example when working in intranet,
you can disable it by using the following configuration:
```
[dependencies]
-grpcio = { version = "0.6", default-features = false, features = ["protobuf-codec"] }
+grpcio = { version = "0.12", default-features = false, features = ["protobuf-codec"] }
```
### Feature `prost-codec` and `protobuf-codec`
@@ -120,7 +119,7 @@ your `Cargo.toml`'s features list for `gprcio`, which requires openssl (>=1.0.2)
```toml
[dependencies]
-grpcio = { version = "0.6", features = ["openssl"] }
+grpcio = { version = "0.12", features = ["openssl"] }
```
Feature `openssl-vendored` is the same as feature `openssl` except it will build openssl from
diff --git a/cargo2android.json b/cargo2android.json
index 7721876..bc4594d 100644
--- a/cargo2android.json
+++ b/cargo2android.json
@@ -3,7 +3,7 @@
"//apex_available:platform"
],
"device": true,
- "features": "protobuf,protobuf-codec",
+ "features": "_secure,boringssl,protobuf-codec,protobuf",
"min-sdk-version": "29",
"run": true,
"vendor-available": true,
diff --git a/src/buf.rs b/src/buf.rs
index de8fe54..7b5f4d3 100644
--- a/src/buf.rs
+++ b/src/buf.rs
@@ -15,7 +15,7 @@ const INLINED_SIZE: usize = mem::size_of::<libc::size_t>() + mem::size_of::<*mut
/// A convenient rust wrapper for the type `grpc_slice`.
///
/// It's expected that the slice should be initialized.
-#[repr(C)]
+#[repr(transparent)]
pub struct GrpcSlice(grpc_slice);
impl GrpcSlice {
@@ -512,30 +512,18 @@ mod tests {
let mut dest = [0; 7];
let amt = reader.read(&mut dest).unwrap();
- assert_eq!(
- dest[..amt],
- expect[..amt],
- "len: {}, nslice: {}",
- len,
- n_slice
- );
+ assert_eq!(dest[..amt], expect[..amt], "len: {len}, nslice: {n_slice}");
// Read after move.
let mut box_reader = Box::new(reader);
let amt = box_reader.read(&mut dest).unwrap();
- assert_eq!(
- dest[..amt],
- expect[..amt],
- "len: {}, nslice: {}",
- len,
- n_slice
- );
+ assert_eq!(dest[..amt], expect[..amt], "len: {len}, nslice: {n_slice}");
// Test read_to_end.
let mut reader = new_message_reader(source.clone(), n_slice);
let mut dest = vec![];
reader.read_to_end(&mut dest).unwrap();
- assert_eq!(dest, expect, "len: {}, nslice: {}", len, n_slice);
+ assert_eq!(dest, expect, "len: {len}, nslice: {n_slice}");
assert_eq!(0, reader.len());
assert_eq!(0, reader.read(&mut [1]).unwrap());
@@ -545,19 +533,11 @@ mod tests {
reader.consume(source.len() * (n_slice - 1));
let mut dest = vec![];
reader.read_to_end(&mut dest).unwrap();
- assert_eq!(
- dest.len(),
- source.len(),
- "len: {}, nslice: {}",
- len,
- n_slice
- );
+ assert_eq!(dest.len(), source.len(), "len: {len}, nslice: {n_slice}");
assert_eq!(
*dest,
expect[expect.len() - source.len()..],
- "len: {}, nslice: {}",
- len,
- n_slice
+ "len: {len}, nslice: {n_slice}"
);
assert_eq!(0, reader.len());
assert_eq!(0, reader.read(&mut [1]).unwrap());
diff --git a/src/call/client.rs b/src/call/client.rs
index e89862a..a183ad4 100644
--- a/src/call/client.rs
+++ b/src/call/client.rs
@@ -1,17 +1,17 @@
// Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0.
+use std::future::Future;
use std::pin::Pin;
use std::ptr;
use std::sync::Arc;
+use std::task::{Context, Poll};
use std::time::Duration;
use crate::grpc_sys;
-use futures::ready;
-use futures::sink::Sink;
-use futures::stream::Stream;
-use futures::task::{Context, Poll};
+use futures_executor::block_on;
+use futures_util::future::poll_fn;
+use futures_util::{ready, Sink, Stream};
use parking_lot::Mutex;
-use std::future::Future;
use super::{ShareCall, ShareCallHolder, SinkBase, WriteFlags};
use crate::buf::GrpcSlice;
@@ -19,7 +19,7 @@ use crate::call::{check_run, Call, MessageReader, Method};
use crate::channel::Channel;
use crate::codec::{DeserializeFn, SerializeFn};
use crate::error::{Error, Result};
-use crate::metadata::Metadata;
+use crate::metadata::{Metadata, UnownedMetadata};
use crate::task::{BatchFuture, BatchType};
/// Update the flag bit in res.
@@ -42,16 +42,6 @@ pub struct CallOption {
}
impl CallOption {
- /// Signal that the call is idempotent.
- pub fn idempotent(mut self, is_idempotent: bool) -> CallOption {
- change_flag(
- &mut self.call_flags,
- grpc_sys::GRPC_INITIAL_METADATA_IDEMPOTENT_REQUEST,
- is_idempotent,
- );
- self
- }
-
/// Signal that the call should not return UNAVAILABLE before it has started.
pub fn wait_for_ready(mut self, wait_for_ready: bool) -> CallOption {
change_flag(
@@ -62,16 +52,6 @@ impl CallOption {
self
}
- /// Signal that the call is cacheable. gRPC is free to use GET verb.
- pub fn cacheable(mut self, cacheable: bool) -> CallOption {
- change_flag(
- &mut self.call_flags,
- grpc_sys::GRPC_INITIAL_METADATA_CACHEABLE_REQUEST,
- cacheable,
- );
- self
- }
-
/// Set write flags.
pub fn write_flags(mut self, write_flags: WriteFlags) -> CallOption {
self.write_flags = write_flags;
@@ -146,12 +126,8 @@ impl Call {
});
let share_call = Arc::new(Mutex::new(ShareCall::new(call, cq_f)));
- let sink = ClientCStreamSender::new(share_call.clone(), method.req_ser());
- let recv = ClientCStreamReceiver {
- call: share_call,
- resp_de: method.resp_de(),
- finished: false,
- };
+ let sink = ClientCStreamSender::new(share_call.clone(), method.req_ser(), opt.call_flags);
+ let recv = ClientCStreamReceiver::new(share_call, method.resp_de());
Ok((sink, recv))
}
@@ -178,12 +154,16 @@ impl Call {
)
});
- // TODO: handle header
- check_run(BatchType::Finish, |ctx, tag| unsafe {
+ let headers_f = check_run(BatchType::Finish, |ctx, tag| unsafe {
grpc_sys::grpcwrap_call_recv_initial_metadata(call.call, ctx, tag)
});
- Ok(ClientSStreamReceiver::new(call, cq_f, method.resp_de()))
+ Ok(ClientSStreamReceiver::new(
+ call,
+ cq_f,
+ method.resp_de(),
+ headers_f,
+ ))
}
pub fn duplex_streaming<Req, Resp>(
@@ -204,14 +184,13 @@ impl Call {
)
});
- // TODO: handle header.
- check_run(BatchType::Finish, |ctx, tag| unsafe {
+ let headers_f = check_run(BatchType::Finish, |ctx, tag| unsafe {
grpc_sys::grpcwrap_call_recv_initial_metadata(call.call, ctx, tag)
});
let share_call = Arc::new(Mutex::new(ShareCall::new(call, cq_f)));
- let sink = ClientDuplexSender::new(share_call.clone(), method.req_ser());
- let recv = ClientDuplexReceiver::new(share_call, method.resp_de());
+ let sink = ClientDuplexSender::new(share_call.clone(), method.req_ser(), opt.call_flags);
+ let recv = ClientDuplexReceiver::new(share_call, method.resp_de(), headers_f);
Ok((sink, recv))
}
}
@@ -224,6 +203,10 @@ pub struct ClientUnaryReceiver<T> {
call: Call,
resp_f: BatchFuture,
resp_de: DeserializeFn<T>,
+ finished: bool,
+ message: Option<T>,
+ initial_metadata: UnownedMetadata,
+ trailing_metadata: UnownedMetadata,
}
impl<T> ClientUnaryReceiver<T> {
@@ -232,6 +215,10 @@ impl<T> ClientUnaryReceiver<T> {
call,
resp_f,
resp_de,
+ finished: false,
+ message: None,
+ initial_metadata: UnownedMetadata::empty(),
+ trailing_metadata: UnownedMetadata::empty(),
}
}
@@ -245,15 +232,65 @@ impl<T> ClientUnaryReceiver<T> {
pub fn resp_de(&self, reader: MessageReader) -> Result<T> {
(self.resp_de)(reader)
}
+
+ async fn wait_for_batch_future(&mut self) -> Result<()> {
+ if self.finished {
+ return Ok(());
+ }
+
+ let data = Pin::new(&mut self.resp_f).await?;
+ self.initial_metadata = data.initial_metadata;
+ self.trailing_metadata = data.trailing_metadata;
+ self.message = Some(self.resp_de(data.message_reader.unwrap())?);
+ self.finished = true;
+ Ok(())
+ }
+
+ pub async fn message(&mut self) -> Result<T> {
+ self.wait_for_batch_future().await?;
+ Ok(self.message.take().unwrap())
+ }
+
+ /// Get the initial metadata.
+ pub async fn headers(&mut self) -> Result<&Metadata> {
+ self.wait_for_batch_future().await?;
+ // Because we have a reference to call, so it's safe to read.
+ Ok(unsafe { self.initial_metadata.assume_valid() })
+ }
+
+ pub async fn trailers(&mut self) -> Result<&Metadata> {
+ self.wait_for_batch_future().await?;
+ // Because we have a reference to call, so it's safe to read.
+ Ok(unsafe { self.trailing_metadata.assume_valid() })
+ }
+
+ pub fn receive_sync(&mut self) -> Result<(Metadata, T, Metadata)> {
+ block_on(async {
+ let headers = self.headers().await?.clone();
+ let message = self.message().await?;
+ let trailer = self.trailers().await?.clone();
+ Ok::<(Metadata, T, Metadata), Error>((headers, message, trailer))
+ })
+ }
}
-impl<T> Future for ClientUnaryReceiver<T> {
+impl<T: Unpin> Future for ClientUnaryReceiver<T> {
type Output = Result<T>;
+ /// Note this method is conflict with method `message`.
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<T>> {
+ if self.finished {
+ if let Some(message) = self.message.take() {
+ return Poll::Ready(Ok(message));
+ }
+ panic!("future should not be polled twice.");
+ }
+
let data = ready!(Pin::new(&mut self.resp_f).poll(cx)?);
- let t = self.resp_de(data.unwrap())?;
- Poll::Ready(Ok(t))
+ self.initial_metadata = data.initial_metadata;
+ self.trailing_metadata = data.trailing_metadata;
+ self.finished = true;
+ Poll::Ready(self.resp_de(data.message_reader.unwrap()))
}
}
@@ -269,9 +306,24 @@ pub struct ClientCStreamReceiver<T> {
call: Arc<Mutex<ShareCall>>,
resp_de: DeserializeFn<T>,
finished: bool,
+ message: Option<T>,
+ initial_metadata: UnownedMetadata,
+ trailing_metadata: UnownedMetadata,
}
impl<T> ClientCStreamReceiver<T> {
+ /// Private constructor to simplify code in `impl Call`
+ fn new(call: Arc<Mutex<ShareCall>>, resp_de: DeserializeFn<T>) -> ClientCStreamReceiver<T> {
+ ClientCStreamReceiver {
+ call,
+ resp_de,
+ finished: false,
+ message: None,
+ initial_metadata: UnownedMetadata::empty(),
+ trailing_metadata: UnownedMetadata::empty(),
+ }
+ }
+
/// Cancel the call.
pub fn cancel(&mut self) {
let lock = self.call.lock();
@@ -282,6 +334,41 @@ impl<T> ClientCStreamReceiver<T> {
pub fn resp_de(&self, reader: MessageReader) -> Result<T> {
(self.resp_de)(reader)
}
+
+ async fn wait_for_batch_future(&mut self) -> Result<()> {
+ if self.finished {
+ return Ok(());
+ }
+ let data = poll_fn(|cx| {
+ let mut call = self.call.lock();
+ call.poll_finish(cx)
+ })
+ .await?;
+
+ self.message = Some(self.resp_de(data.message_reader.unwrap())?);
+ self.initial_metadata = data.initial_metadata;
+ self.trailing_metadata = data.trailing_metadata;
+ self.finished = true;
+ Ok(())
+ }
+
+ pub async fn message(&mut self) -> Result<T> {
+ self.wait_for_batch_future().await?;
+ Ok(self.message.take().unwrap())
+ }
+
+ /// Get the initial metadata.
+ pub async fn headers(&mut self) -> Result<&Metadata> {
+ self.wait_for_batch_future().await?;
+ // We still have a reference in share call.
+ Ok(unsafe { self.initial_metadata.assume_valid() })
+ }
+
+ pub async fn trailers(&mut self) -> Result<&Metadata> {
+ self.wait_for_batch_future().await?;
+ // We still have a reference in share call.
+ Ok(unsafe { self.trailing_metadata.assume_valid() })
+ }
}
impl<T> Drop for ClientCStreamReceiver<T> {
@@ -294,17 +381,26 @@ impl<T> Drop for ClientCStreamReceiver<T> {
}
}
-impl<T> Future for ClientCStreamReceiver<T> {
+impl<T: Unpin> Future for ClientCStreamReceiver<T> {
type Output = Result<T>;
+ /// Note this method is conflict with method `message`.
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<T>> {
+ if self.finished {
+ if let Some(message) = self.message.take() {
+ return Poll::Ready(Ok(message));
+ }
+ panic!("future should not be polled twice.");
+ }
+
let data = {
let mut call = self.call.lock();
ready!(call.poll_finish(cx)?)
};
- let t = (self.resp_de)(data.unwrap())?;
+ self.initial_metadata = data.initial_metadata;
+ self.trailing_metadata = data.trailing_metadata;
self.finished = true;
- Poll::Ready(Ok(t))
+ Poll::Ready((self.resp_de)(data.message_reader.unwrap()))
}
}
@@ -318,15 +414,21 @@ pub struct StreamingCallSink<Req> {
sink_base: SinkBase,
close_f: Option<BatchFuture>,
req_ser: SerializeFn<Req>,
+ call_flags: u32,
}
impl<Req> StreamingCallSink<Req> {
- fn new(call: Arc<Mutex<ShareCall>>, req_ser: SerializeFn<Req>) -> StreamingCallSink<Req> {
+ fn new(
+ call: Arc<Mutex<ShareCall>>,
+ req_ser: SerializeFn<Req>,
+ call_flags: u32,
+ ) -> StreamingCallSink<Req> {
StreamingCallSink {
call,
sink_base: SinkBase::new(false),
close_f: None,
req_ser,
+ call_flags,
}
}
@@ -376,7 +478,7 @@ impl<Req> Sink<(Req, WriteFlags)> for StreamingCallSink<Req> {
call.check_alive()?;
}
let t = &mut *self;
- Pin::new(&mut t.sink_base).start_send(&mut t.call, &msg, flags, t.req_ser)
+ Pin::new(&mut t.sink_base).start_send(&mut t.call, &msg, flags, t.req_ser, t.call_flags)
}
#[inline]
@@ -386,7 +488,7 @@ impl<Req> Sink<(Req, WriteFlags)> for StreamingCallSink<Req> {
call.check_alive()?;
}
let t = &mut *self;
- Pin::new(&mut t.sink_base).poll_flush(cx, &mut t.call)
+ Pin::new(&mut t.sink_base).poll_flush(cx, &mut t.call, t.call_flags)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
@@ -421,22 +523,29 @@ pub type ClientCStreamSender<T> = StreamingCallSink<T>;
/// [`close`]: #method.close
pub type ClientDuplexSender<T> = StreamingCallSink<T>;
+enum FutureOrValue<F, V> {
+ Future(F),
+ Value(V),
+}
+
struct ResponseStreamImpl<H, T> {
call: H,
msg_f: Option<BatchFuture>,
read_done: bool,
finished: bool,
resp_de: DeserializeFn<T>,
+ headers_f: FutureOrValue<BatchFuture, UnownedMetadata>,
}
impl<H: ShareCallHolder + Unpin, T> ResponseStreamImpl<H, T> {
- fn new(call: H, resp_de: DeserializeFn<T>) -> ResponseStreamImpl<H, T> {
+ fn new(call: H, resp_de: DeserializeFn<T>, headers_f: BatchFuture) -> ResponseStreamImpl<H, T> {
ResponseStreamImpl {
call,
msg_f: None,
read_done: false,
finished: false,
resp_de,
+ headers_f: FutureOrValue::Future(headers_f),
}
}
@@ -459,7 +568,8 @@ impl<H: ShareCallHolder + Unpin, T> ResponseStreamImpl<H, T> {
loop {
if !self.read_done {
if let Some(msg_f) = &mut self.msg_f {
- bytes = ready!(Pin::new(msg_f).poll(cx)?);
+ let batch_result = ready!(Pin::new(msg_f).poll(cx)?);
+ bytes = batch_result.message_reader;
if bytes.is_none() {
self.read_done = true;
}
@@ -491,6 +601,17 @@ impl<H: ShareCallHolder + Unpin, T> ResponseStreamImpl<H, T> {
self.cancel();
}
}
+
+ async fn headers(&mut self) -> Result<&Metadata> {
+ if let FutureOrValue::Future(f) = &mut self.headers_f {
+ self.headers_f = FutureOrValue::Value(Pin::new(f).await?.initial_metadata);
+ }
+ match &self.headers_f {
+ // We still have reference to call.
+ FutureOrValue::Value(v) => Ok(unsafe { v.assume_valid() }),
+ _ => unreachable!(),
+ }
+ }
}
/// A receiver for server streaming call.
@@ -504,16 +625,23 @@ impl<Resp> ClientSStreamReceiver<Resp> {
call: Call,
finish_f: BatchFuture,
de: DeserializeFn<Resp>,
+ headers_f: BatchFuture,
) -> ClientSStreamReceiver<Resp> {
let share_call = ShareCall::new(call, finish_f);
ClientSStreamReceiver {
- imp: ResponseStreamImpl::new(share_call, de),
+ imp: ResponseStreamImpl::new(share_call, de, headers_f),
}
}
pub fn cancel(&mut self) {
self.imp.cancel()
}
+
+ /// Get the initial metadata.
+ #[inline]
+ pub async fn headers(&mut self) -> Result<&Metadata> {
+ self.imp.headers().await
+ }
}
impl<Resp> Stream for ClientSStreamReceiver<Resp> {
@@ -538,15 +666,25 @@ pub struct ClientDuplexReceiver<Resp> {
}
impl<Resp> ClientDuplexReceiver<Resp> {
- fn new(call: Arc<Mutex<ShareCall>>, de: DeserializeFn<Resp>) -> ClientDuplexReceiver<Resp> {
+ fn new(
+ call: Arc<Mutex<ShareCall>>,
+ de: DeserializeFn<Resp>,
+ headers_f: BatchFuture,
+ ) -> ClientDuplexReceiver<Resp> {
ClientDuplexReceiver {
- imp: ResponseStreamImpl::new(call, de),
+ imp: ResponseStreamImpl::new(call, de, headers_f),
}
}
pub fn cancel(&mut self) {
self.imp.cancel()
}
+
+ /// Get the initial metadata.
+ #[inline]
+ pub async fn headers(&mut self) -> Result<&Metadata> {
+ self.imp.headers().await
+ }
}
impl<Resp> Drop for ClientDuplexReceiver<Resp> {
diff --git a/src/call/mod.rs b/src/call/mod.rs
index 25174ea..93415fd 100644
--- a/src/call/mod.rs
+++ b/src/call/mod.rs
@@ -3,16 +3,18 @@
pub mod client;
pub mod server;
+use std::ffi::CStr;
use std::fmt::{self, Debug, Display};
+use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
+use std::task::{Context, Poll};
use std::{ptr, slice};
use crate::grpc_sys::{self, grpc_call, grpc_call_error, grpcwrap_batch_context};
+use crate::metadata::UnownedMetadata;
use crate::{cq::CompletionQueue, Metadata, MetadataBuilder};
-use futures::future::Future;
-use futures::ready;
-use futures::task::{Context, Poll};
+use futures_util::ready;
use libc::c_void;
use parking_lot::Mutex;
@@ -20,7 +22,7 @@ use crate::buf::{GrpcByteBuffer, GrpcByteBufferReader, GrpcSlice};
use crate::codec::{DeserializeFn, Marshaller, SerializeFn};
use crate::error::{Error, Result};
use crate::grpc_sys::grpc_status_code::*;
-use crate::task::{self, BatchFuture, BatchType, CallTag};
+use crate::task::{self, BatchFuture, BatchResult, BatchType, CallTag};
/// An gRPC status code structure.
/// This type contains constants for all gRPC status codes.
@@ -165,6 +167,9 @@ pub struct RpcStatus {
///
/// See also https://grpc.io/docs/guides/error/#richer-error-model.
details: Vec<u8>,
+
+ /// Debug error string
+ debug_error_string: String,
}
impl Display for RpcStatus {
@@ -187,17 +192,32 @@ impl RpcStatus {
/// Create a new [`RpcStats`] with code, message and details.
///
/// If using rich error model, `details` should be binary message that sets `code` and
- /// `message` to the same value. Or you can use `into` method to do automatical
+ /// `message` to the same value. Or you can use `into` method to do automatic
/// transformation if using `grpcio_proto::google::rpc::Status`.
pub fn with_details<T: Into<RpcStatusCode>>(
code: T,
message: String,
details: Vec<u8>,
) -> RpcStatus {
+ RpcStatus::with_details_and_error_string(code, message, details, String::new())
+ }
+
+ /// Create a new [`RpcStats`] with code, message, details and debug error string.
+ ///
+ /// If using rich error model, `details` should be binary message that sets `code` and
+ /// `message` to the same value. Or you can use `into` method to do automatic
+ /// transformation if using `grpcio_proto::google::rpc::Status`.
+ pub fn with_details_and_error_string<T: Into<RpcStatusCode>>(
+ code: T,
+ message: String,
+ details: Vec<u8>,
+ debug_error_string: String,
+ ) -> RpcStatus {
RpcStatus {
code: code.into(),
message,
details,
+ debug_error_string,
}
}
@@ -224,6 +244,15 @@ impl RpcStatus {
pub fn details(&self) -> &[u8] {
&self.details
}
+
+ /// Return the debug error string.
+ ///
+ /// This will return a detailed string of the gRPC Core error that led to the failure.
+ /// It shouldn't be relied upon for anything other than gaining more debug data in
+ /// failure cases.
+ pub fn debug_error_string(&self) -> &str {
+ &self.debug_error_string
+ }
}
pub type MessageReader = GrpcByteBufferReader;
@@ -276,7 +305,18 @@ impl BatchContext {
);
let metadata = &*(m_ptr as *const Metadata);
let details = metadata.search_binary_error_details().to_vec();
- RpcStatus::with_details(status, message, details)
+
+ let error_string_ptr =
+ grpc_sys::grpcwrap_batch_context_recv_status_on_client_error_string(self.ctx);
+ let error_string = if error_string_ptr.is_null() {
+ String::new()
+ } else {
+ CStr::from_ptr(error_string_ptr)
+ .to_string_lossy()
+ .into_owned()
+ };
+
+ RpcStatus::with_details_and_error_string(status, message, details, error_string)
}
}
}
@@ -286,6 +326,36 @@ impl BatchContext {
let buf = self.take_recv_message()?;
Some(GrpcByteBufferReader::new(buf))
}
+
+ /// Get the initial metadata from response.
+ ///
+ /// If initial metadata is not fetched or the method has been called, empty metadata will be
+ /// returned.
+ pub fn take_initial_metadata(&mut self) -> UnownedMetadata {
+ let mut res = UnownedMetadata::empty();
+ unsafe {
+ grpcio_sys::grpcwrap_batch_context_take_recv_initial_metadata(
+ self.ctx,
+ res.as_mut_ptr(),
+ );
+ }
+ res
+ }
+
+ /// Get the trailing metadata from response.
+ ///
+ /// If trailing metadata is not fetched or the method has been called, empty metadata will be
+ /// returned.
+ pub fn take_trailing_metadata(&mut self) -> UnownedMetadata {
+ let mut res = UnownedMetadata::empty();
+ unsafe {
+ grpc_sys::grpcwrap_batch_context_take_recv_status_on_client_trailing_metadata(
+ self.ctx,
+ res.as_mut_ptr(),
+ );
+ }
+ res
+ }
}
impl Drop for BatchContext {
@@ -295,11 +365,11 @@ impl Drop for BatchContext {
}
#[inline]
-fn box_batch_tag(tag: CallTag) -> (*mut grpcwrap_batch_context, *mut c_void) {
+fn box_batch_tag(tag: CallTag) -> (*mut grpcwrap_batch_context, *mut CallTag) {
let tag_box = Box::new(tag);
(
tag_box.batch_ctx().unwrap().as_ptr(),
- Box::into_raw(tag_box) as _,
+ Box::into_raw(tag_box),
)
}
@@ -310,10 +380,10 @@ where
{
let (cq_f, tag) = CallTag::batch_pair(bt);
let (batch_ptr, tag_ptr) = box_batch_tag(tag);
- let code = f(batch_ptr, tag_ptr);
+ let code = f(batch_ptr, tag_ptr as *mut c_void);
if code != grpc_call_error::GRPC_CALL_OK {
unsafe {
- Box::from_raw(tag_ptr);
+ drop(Box::from_raw(tag_ptr));
}
panic!("create call fail: {:?}", code);
}
@@ -343,17 +413,18 @@ impl Call {
&mut self,
msg: &mut GrpcSlice,
write_flags: u32,
- initial_meta: bool,
+ initial_metadata: Option<&mut Metadata>,
+ call_flags: u32,
) -> Result<BatchFuture> {
let _cq_ref = self.cq.borrow()?;
- let i = if initial_meta { 1 } else { 0 };
let f = check_run(BatchType::Finish, |ctx, tag| unsafe {
grpc_sys::grpcwrap_call_send_message(
self.call,
ctx,
msg.as_mut_ptr(),
write_flags,
- i,
+ initial_metadata.map_or_else(ptr::null_mut, |m| m as *mut _ as _),
+ call_flags,
tag,
)
});
@@ -393,12 +464,18 @@ impl Call {
pub fn start_send_status_from_server(
&mut self,
status: &RpcStatus,
+ initial_metadata: &mut Option<Metadata>,
+ call_flags: u32,
send_empty_metadata: bool,
payload: &mut Option<GrpcSlice>,
write_flags: u32,
) -> Result<BatchFuture> {
let _cq_ref = self.cq.borrow()?;
- let send_empty_metadata = if send_empty_metadata { 1 } else { 0 };
+
+ if initial_metadata.is_none() && send_empty_metadata {
+ initial_metadata.replace(MetadataBuilder::new().build());
+ }
+
let f = check_run(BatchType::Finish, |ctx, tag| unsafe {
let (msg_ptr, msg_len) = if status.code() == RpcStatusCode::OK {
(ptr::null(), 0)
@@ -409,7 +486,7 @@ impl Call {
Some(p) => p.as_mut_ptr(),
None => ptr::null_mut(),
};
- let mut trailing_metadata = if status.details.is_empty() {
+ let mut trailing_metadata: Option<Metadata> = if status.details.is_empty() {
None
} else {
let mut builder = MetadataBuilder::new();
@@ -422,10 +499,13 @@ impl Call {
status.code().into(),
msg_ptr as _,
msg_len,
+ initial_metadata
+ .as_mut()
+ .map_or_else(ptr::null_mut, |m| m as *mut _ as _),
+ call_flags,
trailing_metadata
.as_mut()
.map_or_else(ptr::null_mut, |m| m as *mut _ as _),
- send_empty_metadata,
payload_p,
write_flags,
tag,
@@ -458,8 +538,9 @@ impl Call {
status.code().into(),
msg_ptr as _,
msg_len,
+ (&mut MetadataBuilder::new().build()) as *mut _ as _,
+ 0,
ptr::null_mut(),
- 1,
ptr::null_mut(),
0,
tag_ptr as *mut c_void,
@@ -467,7 +548,7 @@ impl Call {
};
if code != grpc_call_error::GRPC_CALL_OK {
unsafe {
- Box::from_raw(tag_ptr);
+ drop(Box::from_raw(tag_ptr));
}
panic!("create call fail: {:?}", code);
}
@@ -518,7 +599,7 @@ impl ShareCall {
/// Poll if the call is still alive.
///
/// If the call is still running, will register a notification for its completion.
- fn poll_finish(&mut self, cx: &mut Context) -> Poll<Result<Option<MessageReader>>> {
+ fn poll_finish(&mut self, cx: &mut Context) -> Poll<Result<BatchResult>> {
let res = match Pin::new(&mut self.close_f).poll(cx) {
Poll::Ready(Ok(reader)) => {
self.status = Some(RpcStatus::ok());
@@ -603,7 +684,7 @@ impl StreamingBase {
let mut bytes = None;
if !self.read_done {
if let Some(msg_f) = &mut self.msg_f {
- bytes = ready!(Pin::new(msg_f).poll(cx)?);
+ bytes = ready!(Pin::new(msg_f).poll(cx)?).message_reader;
if bytes.is_none() {
self.read_done = true;
}
@@ -682,6 +763,7 @@ impl WriteFlags {
struct SinkBase {
// Batch job to be executed in `poll_ready`.
batch_f: Option<BatchFuture>,
+ headers: Metadata,
send_metadata: bool,
// Flag to indicate if enhance batch strategy. This behavior will modify the `buffer_hint` to batch
// messages as much as possible.
@@ -699,11 +781,12 @@ impl SinkBase {
fn new(send_metadata: bool) -> SinkBase {
SinkBase {
batch_f: None,
+ headers: MetadataBuilder::new().build(),
+ send_metadata,
+ enhance_buffer_strategy: false,
buffer: GrpcSlice::default(),
buf_flags: None,
last_buf_hint: true,
- send_metadata,
- enhance_buffer_strategy: false,
}
}
@@ -713,13 +796,14 @@ impl SinkBase {
t: &T,
flags: WriteFlags,
ser: SerializeFn<T>,
+ call_flags: u32,
) -> Result<()> {
// temporary fix: buffer hint with send meta will not send out any metadata.
// note: only the first message can enter this code block.
if self.send_metadata {
ser(t, &mut self.buffer)?;
self.buf_flags = Some(flags);
- self.start_send_buffer_message(false, call)?;
+ self.start_send_buffer_message(false, call, call_flags)?;
self.send_metadata = false;
return Ok(());
}
@@ -727,7 +811,7 @@ impl SinkBase {
// If there is already a buffered message waiting to be sent, set `buffer_hint` to true to indicate
// that this is not the last message.
if self.buf_flags.is_some() {
- self.start_send_buffer_message(true, call)?;
+ self.start_send_buffer_message(true, call, call_flags)?;
}
ser(t, &mut self.buffer)?;
@@ -737,7 +821,7 @@ impl SinkBase {
// If sink disable batch, start sending the message in buffer immediately.
if !self.enhance_buffer_strategy {
- self.start_send_buffer_message(hint, call)?;
+ self.start_send_buffer_message(hint, call, call_flags)?;
}
Ok(())
@@ -760,12 +844,13 @@ impl SinkBase {
&mut self,
cx: &mut Context,
call: &mut C,
+ call_flags: u32,
) -> Poll<Result<()>> {
if self.batch_f.is_some() {
ready!(self.poll_ready(cx)?);
}
if self.buf_flags.is_some() {
- self.start_send_buffer_message(self.last_buf_hint, call)?;
+ self.start_send_buffer_message(self.last_buf_hint, call, call_flags)?;
ready!(self.poll_ready(cx)?);
}
self.last_buf_hint = true;
@@ -777,15 +862,24 @@ impl SinkBase {
&mut self,
buffer_hint: bool,
call: &mut C,
+ call_flags: u32,
) -> Result<()> {
// `start_send` is supposed to be called after `poll_ready` returns ready.
assert!(self.batch_f.is_none());
+ let buffer = &mut self.buffer;
let mut flags = self.buf_flags.unwrap();
flags = flags.buffer_hint(buffer_hint);
+
+ let headers = if self.send_metadata {
+ Some(&mut self.headers)
+ } else {
+ None
+ };
+
let write_f = call.call(|c| {
c.call
- .start_send_message(&mut self.buffer, flags.flags, self.send_metadata)
+ .start_send_message(buffer, flags.flags, headers, call_flags)
})?;
self.batch_f = Some(write_f);
if !self.buffer.is_inline() {
diff --git a/src/call/server.rs b/src/call/server.rs
index e762889..35149e0 100644
--- a/src/call/server.rs
+++ b/src/call/server.rs
@@ -1,19 +1,18 @@
// Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0.
use std::ffi::CStr;
+use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
+use std::task::{Context, Poll};
use std::time::Duration;
use std::{result, slice};
use crate::grpc_sys::{
self, gpr_clock_type, gpr_timespec, grpc_call_error, grpcwrap_request_call_context,
};
-use futures::future::Future;
-use futures::ready;
-use futures::sink::Sink;
-use futures::stream::Stream;
-use futures::task::{Context, Poll};
+use futures_util::ready;
+use futures_util::{Sink, Stream};
use parking_lot::Mutex;
use super::{RpcStatus, ShareCall, ShareCallHolder, WriteFlags};
@@ -126,7 +125,7 @@ impl RequestContext {
let call = grpc_sys::grpcwrap_request_call_context_get_call(request_ctx);
let code = grpc_sys::grpcwrap_call_recv_message(call, batch_ctx, tag_ptr as _);
if code != grpc_call_error::GRPC_CALL_OK {
- Box::from_raw(tag_ptr);
+ drop(Box::from_raw(tag_ptr));
// it should not failed.
panic!("try to receive message fail: {:?}", code);
}
@@ -338,6 +337,8 @@ macro_rules! impl_unary_sink {
call: Option<$holder>,
write_flags: u32,
ser: SerializeFn<T>,
+ headers: Option<Metadata>,
+ call_flags: u32,
}
impl<T> $t<T> {
@@ -346,9 +347,22 @@ macro_rules! impl_unary_sink {
call: Some(call),
write_flags: 0,
ser,
+ headers: None,
+ call_flags: 0,
}
}
+ #[inline]
+ pub fn set_headers(&mut self, meta: Metadata) {
+ self.headers = Some(meta);
+ }
+
+ #[inline]
+ pub fn set_call_flags(&mut self, flags: u32) {
+ // TODO: implement a server-side call flags interface similar to the client-side .CallOption.
+ self.call_flags = flags;
+ }
+
pub fn success(self, t: T) -> $rt {
self.complete(RpcStatus::ok(), Some(t))
}
@@ -373,10 +387,13 @@ macro_rules! impl_unary_sink {
None => None,
};
+ let headers = &mut self.headers;
+ let call_flags = self.call_flags;
let write_flags = self.write_flags;
+
let res = self.call.as_mut().unwrap().call(|c| {
c.call
- .start_send_status_from_server(&status, true, &mut data, write_flags)
+ .start_send_status_from_server(&status, headers, call_flags, true, &mut data, write_flags)
});
let (cq_f, err) = match res {
@@ -456,6 +473,10 @@ macro_rules! impl_stream_sink {
}
}
+ pub fn set_headers(&mut self, meta: Metadata) {
+ self.base.headers = meta;
+ }
+
/// By default it always sends messages with their configured buffer hint. But when the
/// `enhance_batch` is enabled, messages will be batched together as many as possible.
/// The rules are listed as below:
@@ -479,7 +500,7 @@ macro_rules! impl_stream_sink {
let send_metadata = self.base.send_metadata;
let res = self.call.as_mut().unwrap().call(|c| {
c.call
- .start_send_status_from_server(&status, send_metadata, &mut None, 0)
+ .start_send_status_from_server(&status, &mut None, 0, send_metadata, &mut None, 0)
});
let (fail_f, err) = match res {
@@ -524,7 +545,7 @@ macro_rules! impl_stream_sink {
#[inline]
fn start_send(mut self: Pin<&mut Self>, (msg, flags): (T, WriteFlags)) -> Result<()> {
let t = &mut *self;
- t.base.start_send(t.call.as_mut().unwrap(), &msg, flags, t.ser)
+ t.base.start_send(t.call.as_mut().unwrap(), &msg, flags, t.ser, 0)
}
#[inline]
@@ -533,7 +554,7 @@ macro_rules! impl_stream_sink {
return Poll::Ready(Err(Error::RemoteStopped));
}
let t = &mut *self;
- Pin::new(&mut t.base).poll_flush(cx, t.call.as_mut().unwrap())
+ Pin::new(&mut t.base).poll_flush(cx, t.call.as_mut().unwrap(), 0)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
@@ -545,7 +566,7 @@ macro_rules! impl_stream_sink {
let status = &t.status;
let flush_f = t.call.as_mut().unwrap().call(|c| {
c.call
- .start_send_status_from_server(status, send_metadata, &mut None, 0)
+ .start_send_status_from_server(status, &mut None, 0, send_metadata, &mut None, 0)
})?;
t.flush_f = Some(flush_f);
}
@@ -714,7 +735,7 @@ pub fn execute_unary<P, Q, F>(
Err(e) => {
let status = RpcStatus::with_message(
RpcStatusCode::INTERNAL,
- format!("Failed to deserialize response message: {:?}", e),
+ format!("Failed to deserialize response message: {e:?}"),
);
call.abort(&status);
return;
@@ -760,7 +781,7 @@ pub fn execute_server_streaming<P, Q, F>(
Err(e) => {
let status = RpcStatus::with_message(
RpcStatusCode::INTERNAL,
- format!("Failed to deserialize response message: {:?}", e),
+ format!("Failed to deserialize response message: {e:?}"),
);
call.abort(&status);
return;
diff --git a/src/channel.rs b/src/channel.rs
index c017f13..057b614 100644
--- a/src/channel.rs
+++ b/src/channel.rs
@@ -21,8 +21,8 @@ use crate::env::Environment;
use crate::error::Result;
use crate::task::CallTag;
use crate::task::Kicker;
-use crate::CallOption;
-use crate::ResourceQuota;
+use crate::{CallOption, ChannelCredentials};
+use crate::{ResourceQuota, RpcStatusCode};
pub use crate::grpc_sys::{
grpc_compression_algorithm as CompressionAlgorithms,
@@ -31,12 +31,12 @@ pub use crate::grpc_sys::{
/// Ref: http://www.grpc.io/docs/guides/wire.html#user-agents
fn format_user_agent_string(agent: &str) -> CString {
- let version = "0.9.1";
+ let version = env!("CARGO_PKG_VERSION");
let trimed_agent = agent.trim();
let val = if trimed_agent.is_empty() {
- format!("grpc-rust/{}", version)
+ format!("grpc-rust/{version}")
} else {
- format!("{} grpc-rust/{}", trimed_agent, version)
+ format!("{trimed_agent} grpc-rust/{version}")
};
CString::new(val).unwrap()
}
@@ -73,6 +73,7 @@ pub enum LbPolicy {
pub struct ChannelBuilder {
env: Arc<Environment>,
options: HashMap<Cow<'static, [u8]>, Options>,
+ credentials: Option<ChannelCredentials>,
}
impl ChannelBuilder {
@@ -81,6 +82,7 @@ impl ChannelBuilder {
ChannelBuilder {
env,
options: HashMap::new(),
+ credentials: None,
}
}
@@ -182,10 +184,9 @@ impl ChannelBuilder {
/// Set whether to allow the use of `SO_REUSEPORT` if available. Defaults to `true`.
pub fn reuse_port(mut self, reuse: bool) -> ChannelBuilder {
- let opt = if reuse { 1 } else { 0 };
self.options.insert(
Cow::Borrowed(grpcio_sys::GRPC_ARG_ALLOW_REUSEPORT),
- Options::Integer(opt),
+ Options::Integer(reuse as i32),
);
self
}
@@ -296,6 +297,15 @@ impl ChannelBuilder {
self
}
+ /// If set to zero, disables use of http proxies.
+ pub fn enable_http_proxy(mut self, num: bool) -> ChannelBuilder {
+ self.options.insert(
+ Cow::Borrowed(grpcio_sys::GRPC_ARG_ENABLE_HTTP_PROXY),
+ Options::Integer(num as i32),
+ );
+ self
+ }
+
/// Set default compression algorithm for the channel.
pub fn default_compression_algorithm(mut self, algo: CompressionAlgorithms) -> ChannelBuilder {
self.options.insert(
@@ -305,6 +315,29 @@ impl ChannelBuilder {
self
}
+ /// Set default gzip compression level.
+ #[cfg(feature = "nightly")]
+ pub fn default_gzip_compression_level(mut self, level: usize) -> ChannelBuilder {
+ self.options.insert(
+ Cow::Borrowed(grpcio_sys::GRPC_ARG_GZIP_COMPRESSION_LEVEL),
+ Options::Integer(level as i32),
+ );
+ self
+ }
+
+ /// Set default grpc min message size to compression.
+ #[cfg(feature = "nightly")]
+ pub fn default_grpc_min_message_size_to_compress(
+ mut self,
+ lower_bound: usize,
+ ) -> ChannelBuilder {
+ self.options.insert(
+ Cow::Borrowed(grpcio_sys::GRPC_ARG_MIN_MESSAGE_SIZE_TO_COMPRESS),
+ Options::Integer(lower_bound as i32),
+ );
+ self
+ }
+
/// Set default compression level for the channel.
pub fn default_compression_level(mut self, level: CompressionLevel) -> ChannelBuilder {
self.options.insert(
@@ -373,6 +406,30 @@ impl ChannelBuilder {
self
}
+ /// Set use local subchannel pool
+ ///
+ /// This method allows channel use it's owned subchannel pool.
+ pub fn use_local_subchannel_pool(mut self, enable: bool) -> ChannelBuilder {
+ self.options.insert(
+ Cow::Borrowed(grpcio_sys::GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL),
+ Options::Integer(enable as i32),
+ );
+ self
+ }
+
+ /// Enables retry functionality. Defaults to true. When enabled, transparent
+ /// retries will be performed as appropriate, and configurable retries are
+ /// enabled when they are configured via the service config. For details, see:
+ /// https://github.com/grpc/proposal/blob/master/A6-client-retries.md
+ /// NOTE: Hedging functionality is not yet implemented.
+ pub fn enable_retry(mut self, enable: bool) -> ChannelBuilder {
+ self.options.insert(
+ Cow::Borrowed(grpcio_sys::GRPC_ARG_ENABLE_RETRIES),
+ Options::Integer(enable as i32),
+ );
+ self
+ }
+
/// Set a raw integer configuration.
///
/// This method is only for bench usage, users should use the encapsulated API instead.
@@ -438,18 +495,21 @@ impl ChannelBuilder {
self.build_args()
}
- /// Build an insecure [`Channel`] that connects to a specific address.
+ /// Build an [`Channel`] that connects to a specific address.
pub fn connect(mut self, addr: &str) -> Channel {
let args = self.prepare_connect_args();
let addr = CString::new(addr).unwrap();
let addr_ptr = addr.as_ptr();
+ let mut creds = self
+ .credentials
+ .unwrap_or_else(ChannelCredentials::insecure);
let channel =
- unsafe { grpc_sys::grpc_insecure_channel_create(addr_ptr, args.args, ptr::null_mut()) };
+ unsafe { grpcio_sys::grpc_channel_create(addr_ptr, creds.as_mut_ptr(), args.args) };
unsafe { Channel::new(self.env.pick_cq(), self.env, channel) }
}
- /// Build an insecure [`Channel`] taking over an established connection from
+ /// Build an [`Channel`] taking over an established connection from
/// a file descriptor. The target string given is purely informative to
/// describe the endpoint of the connection. Takes ownership of the given
/// file descriptor and will close it when the connection is closed.
@@ -466,23 +526,25 @@ impl ChannelBuilder {
let args = self.prepare_connect_args();
let target = CString::new(target).unwrap();
let target_ptr = target.as_ptr();
- let channel = grpc_sys::grpc_insecure_channel_create_from_fd(target_ptr, fd, args.args);
+ // Actually only insecure credentials are supported currently.
+ let mut creds = self
+ .credentials
+ .unwrap_or_else(ChannelCredentials::insecure);
+ let channel =
+ grpcio_sys::grpc_channel_create_from_fd(target_ptr, fd, creds.as_mut_ptr(), args.args);
Channel::new(self.env.pick_cq(), self.env, channel)
}
}
-#[cfg(feature = "secure")]
+#[cfg(feature = "_secure")]
mod secure_channel {
use std::borrow::Cow;
use std::ffi::CString;
- use std::ptr;
-
- use crate::grpc_sys;
use crate::ChannelCredentials;
- use super::{Channel, ChannelBuilder, Options};
+ use super::{ChannelBuilder, Options};
const OPT_SSL_TARGET_NAME_OVERRIDE: &[u8] = b"grpc.ssl_target_name_override\0";
@@ -501,21 +563,10 @@ mod secure_channel {
self
}
- /// Build a secure [`Channel`] that connects to a specific address.
- pub fn secure_connect(mut self, addr: &str, mut creds: ChannelCredentials) -> Channel {
- let args = self.prepare_connect_args();
- let addr = CString::new(addr).unwrap();
- let addr_ptr = addr.as_ptr();
- let channel = unsafe {
- grpc_sys::grpc_secure_channel_create(
- creds.as_mut_ptr(),
- addr_ptr,
- args.args,
- ptr::null_mut(),
- )
- };
-
- unsafe { Channel::new(self.env.pick_cq(), self.env, channel) }
+ /// Set the credentials used to build the connection.
+ pub fn set_credentials(mut self, creds: ChannelCredentials) -> ChannelBuilder {
+ self.credentials = Some(creds);
+ self
}
}
}
@@ -545,8 +596,9 @@ impl ChannelInner {
// If try_to_connect is true, the channel will try to establish a connection, potentially
// changing the state.
fn check_connectivity_state(&self, try_to_connect: bool) -> ConnectivityState {
- let should_try = if try_to_connect { 1 } else { 0 };
- unsafe { grpc_sys::grpc_channel_check_connectivity_state(self.channel, should_try) }
+ unsafe {
+ grpc_sys::grpc_channel_check_connectivity_state(self.channel, try_to_connect as _)
+ }
}
}
@@ -570,6 +622,7 @@ pub struct Channel {
cq: CompletionQueue,
}
+#[allow(clippy::non_send_fields_in_send_ty)]
unsafe impl Send for Channel {}
unsafe impl Sync for Channel {}
@@ -593,6 +646,19 @@ impl Channel {
}
}
+ /// Create a lame channel that will fail all its operations.
+ pub fn lame(env: Arc<Environment>, target: &str) -> Channel {
+ unsafe {
+ let target = CString::new(target).unwrap();
+ let ch = grpc_sys::grpc_lame_client_channel_create(
+ target.as_ptr(),
+ RpcStatusCode::UNAVAILABLE.into(),
+ b"call on lame client\0".as_ptr() as _,
+ );
+ Self::new(env.pick_cq(), env, ch)
+ }
+ }
+
/// If try_to_connect is true, the channel will try to establish a connection, potentially
/// changing the state.
pub fn check_connectivity_state(&self, try_to_connect: bool) -> ConnectivityState {
@@ -712,3 +778,26 @@ impl Channel {
&self.cq
}
}
+
+#[cfg(test)]
+#[cfg(feature = "nightly")]
+mod tests {
+ use crate::env::Environment;
+ use crate::ChannelBuilder;
+ use std::sync::Arc;
+
+ #[test]
+ #[cfg(feature = "nightly")]
+ fn test_grpc_min_message_size_to_compress() {
+ let env = Arc::new(Environment::new(1));
+ let cb = ChannelBuilder::new(env);
+ cb.default_grpc_min_message_size_to_compress(1);
+ }
+ #[test]
+ #[cfg(feature = "nightly")]
+ fn test_gzip_compression_level() {
+ let env = Arc::new(Environment::new(1));
+ let cb = ChannelBuilder::new(env);
+ cb.default_gzip_compression_level(1);
+ }
+}
diff --git a/src/channelz.rs b/src/channelz.rs
new file mode 100644
index 0000000..65180bc
--- /dev/null
+++ b/src/channelz.rs
@@ -0,0 +1,133 @@
+// Copyright 2021 TiKV Project Authors. Licensed under Apache-2.0.
+
+//! Channelz provides channel level debug information. In short, There are four types of
+//! top level entities: channel, subchannel, socket and server. All entities are
+//! identified by an positive unique integer, which is allocated in order. For more
+//! explanation, see https://github.com/grpc/proposal/blob/master/A14-channelz.md.
+//!
+//! A full support requires a service that allow remote querying. But for now it's
+//! too complicated to add full support. Because gRPC C core exposes the information
+//! using JSON format, and there is no protobuf library that supports parsing json
+//! format in Rust. So this module only provides safe APIs to access the informations.
+
+use std::ffi::CStr;
+use std::{cmp, str};
+
+macro_rules! visit {
+ ($ptr:expr, $visitor:ident) => {{
+ let s_ptr = $ptr;
+ let res;
+ if !s_ptr.is_null() {
+ let c_s = CStr::from_ptr(s_ptr);
+ // It's json string, so it must be utf8 compatible.
+ let s = str::from_utf8_unchecked(c_s.to_bytes());
+ res = $visitor(s);
+ grpcio_sys::gpr_free(s_ptr as _);
+ } else {
+ res = $visitor("");
+ }
+ res
+ }};
+}
+
+/// Gets all root channels (i.e. channels the application has directly created). This
+/// does not include subchannels nor non-top level channels.
+pub fn get_top_channels<V, R>(start_channel_id: u64, visitor: V) -> R
+where
+ V: FnOnce(&str) -> R,
+{
+ unsafe {
+ visit!(
+ grpcio_sys::grpc_channelz_get_top_channels(start_channel_id as _),
+ visitor
+ )
+ }
+}
+
+/// Gets all servers that exist in the process.
+pub fn get_servers<V, R>(start_server_id: u64, visitor: V) -> R
+where
+ V: FnOnce(&str) -> R,
+{
+ unsafe {
+ visit!(
+ grpcio_sys::grpc_channelz_get_servers(start_server_id as _),
+ visitor
+ )
+ }
+}
+
+/// Returns a single Server, or else an empty string.
+pub fn get_server<V, R>(server_id: u64, visitor: V) -> R
+where
+ V: FnOnce(&str) -> R,
+{
+ unsafe {
+ visit!(
+ grpcio_sys::grpc_channelz_get_server(server_id as _),
+ visitor
+ )
+ }
+}
+
+/// Gets all server sockets that exist in the server.
+pub fn get_server_sockets<V, R>(
+ server_id: u64,
+ start_socket_id: u64,
+ max_results: usize,
+ visitor: V,
+) -> R
+where
+ V: FnOnce(&str) -> R,
+{
+ let max_results = cmp::min(isize::MAX as usize, max_results) as isize;
+ unsafe {
+ visit!(
+ grpcio_sys::grpc_channelz_get_server_sockets(
+ server_id as _,
+ start_socket_id as _,
+ max_results
+ ),
+ visitor
+ )
+ }
+}
+
+/// Returns a single Channel, or else an empty string.
+pub fn get_channel<V, R>(channel_id: u64, visitor: V) -> R
+where
+ V: FnOnce(&str) -> R,
+{
+ unsafe {
+ visit!(
+ grpcio_sys::grpc_channelz_get_channel(channel_id as _),
+ visitor
+ )
+ }
+}
+
+/// Returns a single Subchannel, or else an empty string.
+pub fn get_subchannel<V, R>(subchannel_id: u64, visitor: V) -> R
+where
+ V: FnOnce(&str) -> R,
+{
+ unsafe {
+ visit!(
+ grpcio_sys::grpc_channelz_get_subchannel(subchannel_id as _),
+ visitor
+ )
+ }
+}
+
+/// Returns a single Socket, or else an empty string.
+pub fn get_socket<V, R>(socket_id: u64, visitor: V) -> R
+where
+ V: FnOnce(&str) -> R,
+{
+ unsafe {
+ visit!(
+ grpcio_sys::grpc_channelz_get_socket(socket_id as _),
+ visitor
+ )
+ }
+}
diff --git a/src/client.rs b/src/client.rs
index 4cce793..b7664b8 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -1,5 +1,7 @@
// Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0.
+use std::future::Future;
+
use crate::call::client::{
CallOption, ClientCStreamReceiver, ClientCStreamSender, ClientDuplexReceiver,
ClientDuplexSender, ClientSStreamReceiver, ClientUnaryReceiver,
@@ -9,8 +11,7 @@ use crate::channel::Channel;
use crate::error::Result;
use crate::task::Executor;
use crate::task::Kicker;
-use futures::executor::block_on;
-use futures::Future;
+use futures_executor::block_on;
/// A generic client for making RPC calls.
#[derive(Clone)]
@@ -29,9 +30,9 @@ impl Client {
/// Create a synchronized unary RPC call.
///
- /// It uses futures::executor::block_on to wait for the futures. It's recommended to use
+ /// It uses futures_executor::block_on to wait for the futures. It's recommended to use
/// the asynchronous version.
- pub fn unary_call<Req, Resp>(
+ pub fn unary_call<Req, Resp: Unpin>(
&self,
method: &Method<Req, Resp>,
req: &Req,
@@ -97,4 +98,10 @@ impl Client {
let kicker = self.kicker.clone();
Executor::new(self.channel.cq()).spawn(f, kicker)
}
+
+ /// Get the underlying channel.
+ #[inline]
+ pub fn channel(&self) -> &Channel {
+ &self.channel
+ }
}
diff --git a/src/codec.rs b/src/codec.rs
index 35dfb2e..e4449e6 100644
--- a/src/codec.rs
+++ b/src/codec.rs
@@ -50,7 +50,7 @@ pub mod pb_codec {
}
} else {
Err(Error::Codec(
- format!("message is too large: {} > {}", cap, MAX_MESSAGE_SIZE).into(),
+ format!("message is too large: {cap} > {MAX_MESSAGE_SIZE}").into(),
))
}
}
@@ -85,7 +85,7 @@ pub mod pr_codec {
Ok(())
} else {
Err(Error::Codec(
- format!("message is too large: {} > {}", size, MAX_MESSAGE_SIZE).into(),
+ format!("message is too large: {size} > {MAX_MESSAGE_SIZE}").into(),
))
}
}
diff --git a/src/env.rs b/src/env.rs
index 5c2e199..2cc2216 100644
--- a/src/env.rs
+++ b/src/env.rs
@@ -95,7 +95,7 @@ impl EnvBuilder {
let tx_i = tx.clone();
let mut builder = ThreadBuilder::new();
if let Some(ref prefix) = self.name_prefix {
- builder = builder.name(format!("{}-{}", prefix, i));
+ builder = builder.name(format!("{prefix}-{i}"));
}
let after_start = self.after_start.clone();
let before_stop = self.before_stop.clone();
diff --git a/src/error.rs b/src/error.rs
index 2d65eb2..260425c 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -1,5 +1,6 @@
// Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0.
+use std::ffi::CString;
use std::{error, fmt, result};
use crate::call::RpcStatus;
@@ -24,7 +25,7 @@ pub enum Error {
/// Failed to shutdown.
ShutdownFailed,
/// Failed to bind.
- BindFail(String, u16),
+ BindFail(CString),
/// gRPC completion queue is shutdown.
QueueShutdown,
/// Failed to create Google default credentials.
@@ -43,7 +44,7 @@ impl fmt::Display for Error {
write!(fmt, "RpcFailure: {} {}", s.code(), s.message())
}
}
- other_error => write!(fmt, "{:?}", other_error),
+ other_error => write!(fmt, "{other_error:?}"),
}
}
}
diff --git a/src/lib.rs b/src/lib.rs
index fd147af..0e5d225 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -12,8 +12,10 @@ framework that puts mobile and HTTP/2 first. grpcio is built on [gRPC Core] and
## Optional features
-- **`secure`** *(enabled by default)* - Enables support for TLS encryption and some authentication
+- **`boringssl`** *(enabled by default)* - Enables support for TLS encryption and some authentication
mechanisms.
+- **`openssl`** - Same as `boringssl`, but base on the system openssl.
+- **`openssl-vendored`** - Same as `openssl`, but build openssl from source.
*/
@@ -21,6 +23,7 @@ framework that puts mobile and HTTP/2 first. grpcio is built on [gRPC Core] and
#![allow(clippy::new_without_default)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::option_map_unit_fn)]
+#![allow(clippy::derive_partial_eq_without_eq)]
use grpcio_sys as grpc_sys;
#[macro_use]
@@ -30,6 +33,7 @@ mod auth_context;
mod buf;
mod call;
mod channel;
+pub mod channelz;
mod client;
mod codec;
mod cq;
@@ -38,7 +42,6 @@ mod error;
mod log_util;
mod metadata;
mod quota;
-#[cfg(feature = "secure")]
mod security;
mod server;
mod task;
@@ -72,11 +75,7 @@ pub use crate::error::{Error, Result};
pub use crate::log_util::redirect_log;
pub use crate::metadata::{Metadata, MetadataBuilder, MetadataIter};
pub use crate::quota::ResourceQuota;
-#[cfg(feature = "secure")]
-pub use crate::security::{
- CertificateRequestType, ChannelCredentials, ChannelCredentialsBuilder, ServerCredentials,
- ServerCredentialsBuilder, ServerCredentialsFetcher,
-};
+pub use crate::security::*;
pub use crate::server::{
CheckResult, Server, ServerBuilder, ServerChecker, Service, ServiceBuilder, ShutdownFuture,
};
diff --git a/src/log_util.rs b/src/log_util.rs
index 3a0cfd6..974f27a 100644
--- a/src/log_util.rs
+++ b/src/log_util.rs
@@ -28,7 +28,7 @@ extern "C" fn delegate(c_args: *mut gpr_log_func_args) {
let msg = unsafe { CStr::from_ptr(args.message).to_string_lossy() };
log::logger().log(
&Record::builder()
- .args(format_args!("{}", msg))
+ .args(format_args!("{msg}"))
.level(level)
.file(file_str.into())
.line(line.into())
diff --git a/src/metadata.rs b/src/metadata.rs
index caaebc8..fc925de 100644
--- a/src/metadata.rs
+++ b/src/metadata.rs
@@ -2,6 +2,7 @@
use crate::grpc_sys::{self, grpc_metadata, grpc_metadata_array};
use std::borrow::Cow;
+use std::fmt;
use std::mem::ManuallyDrop;
use std::{mem, slice, str};
@@ -29,7 +30,7 @@ fn normalize_key(key: &str, binary: bool) -> Result<Cow<'_, str>> {
{
continue;
}
- return Err(Error::InvalidMetadata(format!("key {:?} is invalid", key)));
+ return Err(Error::InvalidMetadata(format!("key {key:?} is invalid")));
}
let key = if is_upper_case {
Cow::Owned(key.to_ascii_lowercase())
@@ -146,7 +147,7 @@ impl MetadataBuilder {
///
/// Metadata value can be ascii string or bytes. They are distinguish by the
/// key suffix, key of bytes value should have suffix '-bin'.
-#[repr(C)]
+#[repr(transparent)]
pub struct Metadata(grpc_metadata_array);
impl Metadata {
@@ -235,6 +236,17 @@ impl Metadata {
}
}
+impl fmt::Debug for Metadata {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt.debug_map()
+ .entries(
+ self.iter()
+ .map(|(k, v)| (k, std::str::from_utf8(v).unwrap_or("?"))),
+ )
+ .finish()
+ }
+}
+
impl Clone for Metadata {
fn clone(&self) -> Metadata {
let mut builder = MetadataBuilder::with_capacity(self.len());
@@ -255,6 +267,39 @@ impl Drop for Metadata {
}
unsafe impl Send for Metadata {}
+unsafe impl Sync for Metadata {}
+
+/// A special metadata that only for receiving metadata from remote.
+///
+/// gRPC C Core manages metadata internally, it's unsafe to read them unless
+/// call is not destroyed.
+#[repr(transparent)]
+pub struct UnownedMetadata(grpc_metadata_array);
+
+impl UnownedMetadata {
+ #[inline]
+ pub fn empty() -> UnownedMetadata {
+ unsafe { mem::transmute(Metadata::with_capacity(0)) }
+ }
+ #[inline]
+ pub unsafe fn assume_valid(&self) -> &Metadata {
+ mem::transmute(self)
+ }
+
+ pub fn as_mut_ptr(&mut self) -> *mut grpc_metadata_array {
+ &mut self.0 as _
+ }
+}
+
+impl Drop for UnownedMetadata {
+ #[inline]
+ fn drop(&mut self) {
+ unsafe { grpcio_sys::grpcwrap_metadata_array_destroy_metadata_only(&mut self.0) }
+ }
+}
+
+unsafe impl Send for UnownedMetadata {}
+unsafe impl Sync for UnownedMetadata {}
/// Immutable metadata iterator
///
@@ -327,14 +372,14 @@ mod tests {
let mut builder = MetadataBuilder::new();
let mut meta_kvs = vec![];
for i in 0..5 {
- let key = format!("K{}", i);
- let val = format!("v{}", i);
+ let key = format!("K{i}");
+ let val = format!("v{i}");
builder.add_str(&key, &val).unwrap();
meta_kvs.push((key.to_ascii_lowercase(), val.into_bytes()));
}
for i in 5..10 {
- let key = format!("k{}-Bin", i);
- let val = format!("v{}", i);
+ let key = format!("k{i}-Bin");
+ let val = format!("v{i}");
builder.add_bytes(&key, val.as_bytes()).unwrap();
meta_kvs.push((key.to_ascii_lowercase(), val.into_bytes()));
}
diff --git a/src/quota.rs b/src/quota.rs
index 7891312..7ef296d 100644
--- a/src/quota.rs
+++ b/src/quota.rs
@@ -40,6 +40,15 @@ impl ResourceQuota {
}
}
+impl Clone for ResourceQuota {
+ fn clone(&self) -> Self {
+ unsafe {
+ grpc_sys::grpc_resource_quota_ref(self.raw);
+ }
+ Self { raw: self.raw }
+ }
+}
+
impl Drop for ResourceQuota {
fn drop(&mut self) {
unsafe {
diff --git a/src/security/credentials.rs b/src/security/credentials.rs
index 7d73009..1627fa2 100644
--- a/src/security/credentials.rs
+++ b/src/security/credentials.rs
@@ -8,9 +8,9 @@ use crate::error::{Error, Result};
use crate::grpc_sys::grpc_ssl_certificate_config_reload_status::{self, *};
use crate::grpc_sys::grpc_ssl_client_certificate_request_type::*;
use crate::grpc_sys::{
- self, grpc_channel_credentials, grpc_server_credentials,
- grpc_ssl_client_certificate_request_type, grpc_ssl_server_certificate_config,
+ self, grpc_ssl_client_certificate_request_type, grpc_ssl_server_certificate_config,
};
+use crate::{ChannelCredentials, ServerCredentials};
#[repr(u32)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
@@ -102,7 +102,7 @@ pub(crate) unsafe extern "C" fn server_cert_fetcher_wrapper(
panic!("fetcher user_data must be set up!");
}
let f: &mut dyn ServerCredentialsFetcher =
- (&mut *(user_data as *mut Box<dyn ServerCredentialsFetcher>)).as_mut();
+ (*(user_data as *mut Box<dyn ServerCredentialsFetcher>)).as_mut();
let result = f.fetch();
match result {
Ok(Some(builder)) => {
@@ -178,22 +178,21 @@ impl ServerCredentialsBuilder {
self.key_cert_pairs.len(),
);
if !root_cert.is_null() {
- CString::from_raw(root_cert);
+ drop(CString::from_raw(root_cert));
}
cfg
}
/// Finalize the [`ServerCredentialsBuilder`] and build the [`ServerCredentials`].
pub fn build(self) -> ServerCredentials {
- let credentials = unsafe {
+ unsafe {
let opt = grpcio_sys::grpc_ssl_server_credentials_create_options_using_config(
self.cer_request_type.to_native(),
self.build_config(),
);
- grpcio_sys::grpc_ssl_server_credentials_create_with_options(opt)
- };
-
- ServerCredentials { creds: credentials }
+ let credentials = grpcio_sys::grpc_ssl_server_credentials_create_with_options(opt);
+ ServerCredentials::from_raw(credentials)
+ }
}
}
@@ -201,7 +200,7 @@ impl Drop for ServerCredentialsBuilder {
fn drop(&mut self) {
for pair in self.key_cert_pairs.drain(..) {
unsafe {
- CString::from_raw(pair.cert_chain as *mut _);
+ drop(CString::from_raw(pair.cert_chain as *mut _));
let s = CString::from_raw(pair.private_key as *mut _);
clear_key_securely(&mut s.into_bytes_with_nul());
}
@@ -209,29 +208,28 @@ impl Drop for ServerCredentialsBuilder {
}
}
-/// Server-side SSL credentials.
-///
-/// Use [`ServerCredentialsBuilder`] to build a [`ServerCredentials`].
-pub struct ServerCredentials {
- creds: *mut grpc_server_credentials,
-}
-
-unsafe impl Send for ServerCredentials {}
-
impl ServerCredentials {
- pub(crate) unsafe fn frow_raw(creds: *mut grpc_server_credentials) -> ServerCredentials {
- ServerCredentials { creds }
- }
-
- pub fn as_mut_ptr(&mut self) -> *mut grpc_server_credentials {
- self.creds
- }
-}
-
-impl Drop for ServerCredentials {
- fn drop(&mut self) {
+ /// Creates the credentials using a certificate config fetcher. Use this
+ /// method to reload the certificates and keys of the SSL server without
+ /// interrupting the operation of the server. Initial certificate config will be
+ /// fetched during server initialization.
+ pub fn with_fetcher(
+ fetcher: Box<dyn ServerCredentialsFetcher + Send + Sync>,
+ cer_request_type: CertificateRequestType,
+ ) -> Self {
+ let fetcher_wrap = Box::new(fetcher);
+ let fetcher_wrap_ptr = Box::into_raw(fetcher_wrap);
unsafe {
- grpc_sys::grpc_server_credentials_release(self.creds);
+ let opt = grpcio_sys::grpc_ssl_server_credentials_create_options_using_config_fetcher(
+ cer_request_type.to_native(),
+ Some(server_cert_fetcher_wrapper),
+ fetcher_wrap_ptr as _,
+ );
+ let mut creds = ServerCredentials::from_raw(
+ grpcio_sys::grpc_ssl_server_credentials_create_with_options(opt),
+ );
+ creds._fetcher = Some(Box::from_raw(fetcher_wrap_ptr));
+ creds
}
}
}
@@ -331,19 +329,7 @@ impl Drop for ChannelCredentialsBuilder {
}
}
-/// Client-side SSL credentials.
-///
-/// Use [`ChannelCredentialsBuilder`] or [`ChannelCredentials::google_default_credentials`] to
-/// build a [`ChannelCredentials`].
-pub struct ChannelCredentials {
- creds: *mut grpc_channel_credentials,
-}
-
impl ChannelCredentials {
- pub fn as_mut_ptr(&mut self) -> *mut grpc_channel_credentials {
- self.creds
- }
-
/// Try to build a [`ChannelCredentials`] to authenticate with Google OAuth credentials.
pub fn google_default_credentials() -> Result<ChannelCredentials> {
// Initialize the runtime here. Because this is an associated method
@@ -360,9 +346,3 @@ impl ChannelCredentials {
}
}
}
-
-impl Drop for ChannelCredentials {
- fn drop(&mut self) {
- unsafe { grpc_sys::grpc_channel_credentials_release(self.creds) }
- }
-}
diff --git a/src/security/mod.rs b/src/security/mod.rs
index f2c2bad..c935461 100644
--- a/src/security/mod.rs
+++ b/src/security/mod.rs
@@ -1,10 +1,81 @@
// Copyright 2020 TiKV Project Authors. Licensed under Apache-2.0.
+#[cfg(feature = "_secure")]
mod credentials;
+use grpcio_sys::{grpc_channel_credentials, grpc_server_credentials};
+
+#[cfg(feature = "_secure")]
pub use self::credentials::{
- CertificateRequestType, ChannelCredentials, ChannelCredentialsBuilder, ServerCredentials,
- ServerCredentialsBuilder, ServerCredentialsFetcher,
+ CertificateRequestType, ChannelCredentialsBuilder, ServerCredentialsBuilder,
+ ServerCredentialsFetcher,
};
-pub(crate) use self::credentials::server_cert_fetcher_wrapper;
+/// Client-side SSL credentials.
+///
+/// Use [`ChannelCredentialsBuilder`] or [`ChannelCredentials::google_default_credentials`] to
+/// build a [`ChannelCredentials`].
+pub struct ChannelCredentials {
+ creds: *mut grpc_channel_credentials,
+}
+
+impl ChannelCredentials {
+ pub fn as_mut_ptr(&mut self) -> *mut grpc_channel_credentials {
+ self.creds
+ }
+
+ /// Creates an insecure channel credentials object.
+ pub fn insecure() -> ChannelCredentials {
+ unsafe {
+ let creds = grpcio_sys::grpc_insecure_credentials_create();
+ ChannelCredentials { creds }
+ }
+ }
+}
+
+impl Drop for ChannelCredentials {
+ fn drop(&mut self) {
+ unsafe { grpcio_sys::grpc_channel_credentials_release(self.creds) }
+ }
+}
+
+/// Server-side SSL credentials.
+///
+/// Use [`ServerCredentialsBuilder`] to build a [`ServerCredentials`].
+pub struct ServerCredentials {
+ creds: *mut grpc_server_credentials,
+ // Double allocation to get around C call.
+ #[cfg(feature = "_secure")]
+ _fetcher: Option<Box<Box<dyn crate::ServerCredentialsFetcher + Send + Sync>>>,
+}
+
+unsafe impl Send for ServerCredentials {}
+
+impl ServerCredentials {
+ /// Creates an insecure server credentials object.
+ pub fn insecure() -> ServerCredentials {
+ unsafe {
+ let creds = grpcio_sys::grpc_insecure_server_credentials_create();
+ ServerCredentials::from_raw(creds)
+ }
+ }
+ pub(crate) unsafe fn from_raw(creds: *mut grpc_server_credentials) -> ServerCredentials {
+ ServerCredentials {
+ creds,
+ #[cfg(feature = "_secure")]
+ _fetcher: None,
+ }
+ }
+
+ pub fn as_mut_ptr(&mut self) -> *mut grpc_server_credentials {
+ self.creds
+ }
+}
+
+impl Drop for ServerCredentials {
+ fn drop(&mut self) {
+ unsafe {
+ grpcio_sys::grpc_server_credentials_release(self.creds);
+ }
+ }
+}
diff --git a/src/server.rs b/src/server.rs
index f2150a0..042217e 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -2,17 +2,17 @@
use std::cell::UnsafeCell;
use std::collections::HashMap;
+use std::ffi::CString;
use std::fmt::{self, Debug, Formatter};
-use std::net::{IpAddr, SocketAddr};
+use std::future::Future;
use std::pin::Pin;
use std::ptr;
use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
+use std::task::{Context, Poll};
use crate::grpc_sys::{self, grpc_call_error, grpc_server};
-use futures::future::Future;
-use futures::ready;
-use futures::task::{Context, Poll};
+use futures_util::ready;
use crate::call::server::*;
use crate::call::{MessageReader, Method, MethodType};
@@ -21,8 +21,8 @@ use crate::cq::CompletionQueue;
use crate::env::Environment;
use crate::error::{Error, Result};
use crate::task::{CallTag, CqFuture};
-use crate::RpcContext;
use crate::RpcStatus;
+use crate::{RpcContext, ServerCredentials};
const DEFAULT_REQUEST_SLOTS_PER_CQ: usize = 1024;
@@ -65,106 +65,6 @@ where
}
}
-/// Given a host and port, creates a string of the form "host:port" or
-/// "[host]:port", depending on whether the host is an IPv6 literal.
-fn join_host_port(host: &str, port: u16) -> String {
- if host.starts_with("unix:") | host.starts_with("unix-abstract:") {
- format!("{}\0", host)
- } else if let Ok(ip) = host.parse::<IpAddr>() {
- format!("{}\0", SocketAddr::new(ip, port))
- } else {
- format!("{}:{}\0", host, port)
- }
-}
-
-#[cfg(feature = "secure")]
-mod imp {
- use super::join_host_port;
- use crate::grpc_sys::{self, grpc_server};
- use crate::security::ServerCredentialsFetcher;
- use crate::ServerCredentials;
-
- pub struct Binder {
- pub host: String,
- pub port: u16,
- cred: Option<ServerCredentials>,
- // Double allocation to get around C call.
- #[allow(clippy::redundant_allocation)]
- _fetcher: Option<Box<Box<dyn ServerCredentialsFetcher + Send + Sync>>>,
- }
-
- impl Binder {
- pub fn new(host: String, port: u16) -> Binder {
- let cred = None;
- Binder {
- host,
- port,
- cred,
- _fetcher: None,
- }
- }
-
- #[allow(clippy::redundant_allocation)]
- pub fn with_cred(
- host: String,
- port: u16,
- cred: ServerCredentials,
- _fetcher: Option<Box<Box<dyn ServerCredentialsFetcher + Send + Sync>>>,
- ) -> Binder {
- let cred = Some(cred);
- Binder {
- host,
- port,
- cred,
- _fetcher,
- }
- }
-
- pub unsafe fn bind(&mut self, server: *mut grpc_server) -> u16 {
- let addr = join_host_port(&self.host, self.port);
- let port = match self.cred.take() {
- None => grpc_sys::grpc_server_add_insecure_http2_port(server, addr.as_ptr() as _),
- Some(mut cert) => grpc_sys::grpc_server_add_secure_http2_port(
- server,
- addr.as_ptr() as _,
- cert.as_mut_ptr(),
- ),
- };
- port as u16
- }
- }
-}
-
-#[cfg(not(feature = "secure"))]
-mod imp {
- use super::join_host_port;
- use crate::grpc_sys::{self, grpc_server};
-
- pub struct Binder {
- pub host: String,
- pub port: u16,
- }
-
- impl Binder {
- pub fn new(host: String, port: u16) -> Binder {
- Binder { host, port }
- }
-
- pub unsafe fn bind(&mut self, server: *mut grpc_server) -> u16 {
- let addr = join_host_port(&self.host, self.port);
- grpc_sys::grpc_server_add_insecure_http2_port(server, addr.as_ptr() as _) as u16
- }
- }
-}
-
-use self::imp::Binder;
-
-impl Debug for Binder {
- fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
- write!(f, "Binder {{ host: {}, port: {} }}", self.host, self.port)
- }
-}
-
/// [`Service`] factory in order to configure the properties.
///
/// Use it to build a service which can be registered to a server.
@@ -299,7 +199,6 @@ pub struct Service {
/// [`Server`] factory in order to configure the properties.
pub struct ServerBuilder {
env: Arc<Environment>,
- binders: Vec<Binder>,
args: Option<ChannelArgs>,
slots_per_cq: usize,
handlers: HashMap<&'static [u8], BoxHandler>,
@@ -311,7 +210,6 @@ impl ServerBuilder {
pub fn new(env: Arc<Environment>) -> ServerBuilder {
ServerBuilder {
env,
- binders: Vec::new(),
args: None,
slots_per_cq: DEFAULT_REQUEST_SLOTS_PER_CQ,
handlers: HashMap::new(),
@@ -319,14 +217,6 @@ impl ServerBuilder {
}
}
- /// Bind to an address.
- ///
- /// This function can be called multiple times to bind to multiple ports.
- pub fn bind<S: Into<String>>(mut self, host: S, port: u16) -> ServerBuilder {
- self.binders.push(Binder::new(host.into(), port));
- self
- }
-
/// Add additional configuration for each incoming channel.
pub fn channel_args(mut self, args: ChannelArgs) -> ServerBuilder {
self.args = Some(args);
@@ -356,22 +246,13 @@ impl ServerBuilder {
}
/// Finalize the [`ServerBuilder`] and build the [`Server`].
- pub fn build(mut self) -> Result<Server> {
+ pub fn build(self) -> Result<Server> {
let args = self
.args
.as_ref()
.map_or_else(ptr::null, ChannelArgs::as_ptr);
unsafe {
let server = grpc_sys::grpc_server_create(args, ptr::null_mut());
- for binder in self.binders.iter_mut() {
- let bind_port = binder.bind(server);
- if bind_port == 0 {
- grpc_sys::grpc_server_destroy(server);
- return Err(Error::BindFail(binder.host.clone(), binder.port));
- }
- binder.port = bind_port;
- }
-
for cq in self.env.completion_queues() {
let cq_ref = cq.borrow()?;
grpc_sys::grpc_server_register_completion_queue(
@@ -385,8 +266,8 @@ impl ServerBuilder {
env: self.env,
core: Arc::new(ServerCore {
server,
+ creds: Mutex::new(Vec::new()),
shutdown: AtomicBool::new(false),
- binders: self.binders,
slots_per_cq: self.slots_per_cq,
}),
handlers: self.handlers,
@@ -396,66 +277,9 @@ impl ServerBuilder {
}
}
-#[cfg(feature = "secure")]
-mod secure_server {
- use super::{Binder, ServerBuilder};
- use crate::grpc_sys;
- use crate::security::{
- server_cert_fetcher_wrapper, CertificateRequestType, ServerCredentials,
- ServerCredentialsFetcher,
- };
-
- impl ServerBuilder {
- /// Bind to an address with credentials for secure connection.
- ///
- /// This function can be called multiple times to bind to multiple ports.
- pub fn bind_with_cred<S: Into<String>>(
- mut self,
- host: S,
- port: u16,
- c: ServerCredentials,
- ) -> ServerBuilder {
- self.binders
- .push(Binder::with_cred(host.into(), port, c, None));
- self
- }
-
- /// Bind to an address for secure connection.
- ///
- /// The required credentials will be fetched using provided `fetcher`. This
- /// function can be called multiple times to bind to multiple ports.
- pub fn bind_with_fetcher<S: Into<String>>(
- mut self,
- host: S,
- port: u16,
- fetcher: Box<dyn ServerCredentialsFetcher + Send + Sync>,
- cer_request_type: CertificateRequestType,
- ) -> ServerBuilder {
- let fetcher_wrap = Box::new(fetcher);
- let fetcher_wrap_ptr = Box::into_raw(fetcher_wrap);
- let (sc, fb) = unsafe {
- let opt = grpc_sys::grpc_ssl_server_credentials_create_options_using_config_fetcher(
- cer_request_type.to_native(),
- Some(server_cert_fetcher_wrapper),
- fetcher_wrap_ptr as _,
- );
- (
- ServerCredentials::frow_raw(
- grpcio_sys::grpc_ssl_server_credentials_create_with_options(opt),
- ),
- Box::from_raw(fetcher_wrap_ptr),
- )
- };
- self.binders
- .push(Binder::with_cred(host.into(), port, sc, Some(fb)));
- self
- }
- }
-}
-
struct ServerCore {
server: *mut grpc_server,
- binders: Vec<Binder>,
+ creds: Mutex<Vec<ServerCredentials>>,
slots_per_cq: usize,
shutdown: AtomicBool,
}
@@ -494,6 +318,7 @@ impl RequestCallContext {
// Apparently, its life time is guaranteed by the ref count, hence is safe to be sent
// to other thread. However it's not `Sync`, as `BoxHandler` is unnecessarily `Sync`.
+#[allow(clippy::non_send_fields_in_send_ty)]
unsafe impl Send for RequestCallContext {}
/// Request notification of a new call.
@@ -520,7 +345,7 @@ pub fn request_call(ctx: RequestCallContext, cq: &CompletionQueue) {
)
};
if code != grpc_call_error::GRPC_CALL_OK {
- Box::from(tag);
+ drop(Box::from(tag));
panic!("failed to request call: {:?}", code);
}
}
@@ -605,9 +430,34 @@ impl Server {
}
}
- /// Get binded addresses pairs.
- pub fn bind_addrs(&self) -> impl ExactSizeIterator<Item = (&String, u16)> {
- self.core.binders.iter().map(|b| (&b.host, b.port))
+ /// Try binding the server to the given `addr` endpoint (eg, localhost:1234,
+ /// 192.168.1.1:31416, [::1]:27182, etc.).
+ ///
+ /// It can be invoked multiple times. Should be used before starting the server.
+ ///
+ /// # Return
+ ///
+ /// The bound port is returned on success.
+ pub fn add_listening_port(
+ &mut self,
+ addr: impl Into<String>,
+ mut creds: ServerCredentials,
+ ) -> Result<u16> {
+ // There is no Null in UTF-8 string.
+ let addr = CString::new(addr.into()).unwrap();
+ let port = unsafe {
+ grpcio_sys::grpc_server_add_http2_port(
+ self.core.server,
+ addr.as_ptr() as _,
+ creds.as_mut_ptr(),
+ ) as u16
+ };
+ if port != 0 {
+ self.core.creds.lock().unwrap().push(creds);
+ Ok(port)
+ } else {
+ Err(Error::BindFail(addr))
+ }
}
/// Add an rpc channel for an established connection represented as a file
@@ -620,8 +470,9 @@ impl Server {
/// this call, the socket must not be accessed (read / written / closed)
/// by other code.
#[cfg(unix)]
- pub unsafe fn add_insecure_channel_from_fd(&self, fd: ::std::os::raw::c_int) {
- grpc_sys::grpc_server_add_insecure_channel_from_fd(self.core.server, ptr::null_mut(), fd)
+ pub unsafe fn add_channel_from_fd(&mut self, fd: ::std::os::raw::c_int) {
+ let mut creds = ServerCredentials::insecure();
+ grpcio_sys::grpc_server_add_channel_from_fd(self.core.server, fd, creds.as_mut_ptr())
}
}
@@ -635,35 +486,17 @@ impl Drop for Server {
None
};
self.cancel_all_calls();
- let _ = f.map(futures::executor::block_on);
+ let _ = f.map(futures_executor::block_on);
}
}
impl Debug for Server {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
- write!(f, "Server {:?}", self.core.binders)
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::join_host_port;
-
- #[test]
- fn test_join_host_port() {
- let tbl = vec![
- ("localhost", 0u16, "localhost:0\0"),
- ("127.0.0.1", 100u16, "127.0.0.1:100\0"),
- ("::1", 0u16, "[::1]:0\0"),
- (
- "fe80::7376:45d5:fb08:61e3",
- 10028u16,
- "[fe80::7376:45d5:fb08:61e3]:10028\0",
- ),
- ];
-
- for (h, p, e) in &tbl {
- assert_eq!(join_host_port(h, *p), e.to_owned());
- }
+ write!(
+ f,
+ "Server {{ handlers: {}, checkers: {} }}",
+ self.handlers.len(),
+ self.checkers.len()
+ )
}
}
diff --git a/src/task/executor.rs b/src/task/executor.rs
index 4a13905..3941005 100644
--- a/src/task/executor.rs
+++ b/src/task/executor.rs
@@ -8,12 +8,13 @@
//! same completion queue as its inner call. Hence method `Executor::spawn` is provided.
use std::cell::UnsafeCell;
+use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::Arc;
+use std::task::{Context, Poll};
-use futures::future::Future;
-use futures::task::{waker_ref, ArcWake, Context, Poll};
+use futures_util::task::{waker_ref, ArcWake};
use super::CallTag;
use crate::call::Call;
diff --git a/src/task/mod.rs b/src/task/mod.rs
index 53369f1..d1827fc 100644
--- a/src/task/mod.rs
+++ b/src/task/mod.rs
@@ -5,23 +5,24 @@ mod executor;
mod promise;
use std::fmt::{self, Debug, Formatter};
+use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
+use std::task::{Context, Poll, Waker};
-use futures::future::Future;
-use futures::task::{Context, Poll, Waker};
use parking_lot::Mutex;
use self::callback::{Abort, Request as RequestCallback, UnaryRequest as UnaryRequestCallback};
use self::executor::SpawnTask;
use self::promise::{Action as ActionPromise, Batch as BatchPromise};
use crate::call::server::RequestContext;
-use crate::call::{BatchContext, Call, MessageReader};
+use crate::call::{BatchContext, Call};
use crate::cq::CompletionQueue;
use crate::error::{Error, Result};
use crate::server::RequestCallContext;
pub(crate) use self::executor::{Executor, Kicker, UnfinishedWork};
+pub(crate) use self::promise::BatchResult;
pub use self::promise::BatchType;
/// A handle that is used to notify future that the task finishes.
@@ -104,7 +105,7 @@ impl<T> Future for CqFuture<T> {
}
/// Future object for batch jobs.
-pub type BatchFuture = CqFuture<Option<MessageReader>>;
+pub type BatchFuture = CqFuture<BatchResult>;
/// A result holder for asynchronous execution.
// This enum is going to be passed to FFI, so don't use trait or generic here.
@@ -185,7 +186,7 @@ impl CallTag {
impl Debug for CallTag {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match *self {
- CallTag::Batch(ref ctx) => write!(f, "CallTag::Batch({:?})", ctx),
+ CallTag::Batch(ref ctx) => write!(f, "CallTag::Batch({ctx:?})"),
CallTag::Request(_) => write!(f, "CallTag::Request(..)"),
CallTag::UnaryRequest(_) => write!(f, "CallTag::UnaryRequest(..)"),
CallTag::Abort(_) => write!(f, "CallTag::Abort(..)"),
@@ -203,7 +204,7 @@ mod tests {
use super::*;
use crate::env::Environment;
- use futures::executor::block_on;
+ use futures_executor::block_on;
#[test]
fn test_resolve() {
diff --git a/src/task/promise.rs b/src/task/promise.rs
index 2d826d4..e9b3646 100644
--- a/src/task/promise.rs
+++ b/src/task/promise.rs
@@ -6,6 +6,7 @@ use std::sync::Arc;
use super::Inner;
use crate::call::{BatchContext, MessageReader, RpcStatusCode};
use crate::error::Error;
+use crate::metadata::UnownedMetadata;
/// Batch job type.
#[derive(PartialEq, Debug)]
@@ -18,15 +19,46 @@ pub enum BatchType {
CheckRead,
}
+/// A promise result which stores a message reader with bundled metadata.
+pub struct BatchResult {
+ pub message_reader: Option<MessageReader>,
+ pub initial_metadata: UnownedMetadata,
+ pub trailing_metadata: UnownedMetadata,
+}
+
+impl BatchResult {
+ pub fn new(
+ message_reader: Option<MessageReader>,
+ initial_metadata: Option<UnownedMetadata>,
+ trailing_metadata: Option<UnownedMetadata>,
+ ) -> BatchResult {
+ let initial_metadata = if let Some(m) = initial_metadata {
+ m
+ } else {
+ UnownedMetadata::empty()
+ };
+ let trailing_metadata = if let Some(m) = trailing_metadata {
+ m
+ } else {
+ UnownedMetadata::empty()
+ };
+ BatchResult {
+ message_reader,
+ initial_metadata,
+ trailing_metadata,
+ }
+ }
+}
+
/// A promise used to resolve batch jobs.
pub struct Batch {
ty: BatchType,
ctx: BatchContext,
- inner: Arc<Inner<Option<MessageReader>>>,
+ inner: Arc<Inner<BatchResult>>,
}
impl Batch {
- pub fn new(ty: BatchType, inner: Arc<Inner<Option<MessageReader>>>) -> Batch {
+ pub fn new(ty: BatchType, inner: Arc<Inner<BatchResult>>) -> Batch {
Batch {
ty,
ctx: BatchContext::new(),
@@ -42,11 +74,11 @@ impl Batch {
let task = {
let mut guard = self.inner.lock();
if success {
- guard.set_result(Ok(self.ctx.recv_message()))
+ guard.set_result(Ok(BatchResult::new(self.ctx.recv_message(), None, None)))
} else {
// rely on C core to handle the failed read (e.g. deliver approriate
// statusCode on the clientside).
- guard.set_result(Ok(None))
+ guard.set_result(Ok(BatchResult::new(None, None, None)))
}
};
task.map(|t| t.wake());
@@ -58,7 +90,11 @@ impl Batch {
if succeed {
let status = self.ctx.rpc_status();
if status.code() == RpcStatusCode::OK {
- guard.set_result(Ok(None))
+ guard.set_result(Ok(BatchResult::new(
+ None,
+ Some(self.ctx.take_initial_metadata()),
+ Some(self.ctx.take_trailing_metadata()),
+ )))
} else {
guard.set_result(Err(Error::RpcFailure(status)))
}
@@ -74,7 +110,11 @@ impl Batch {
let mut guard = self.inner.lock();
let status = self.ctx.rpc_status();
if status.code() == RpcStatusCode::OK {
- guard.set_result(Ok(self.ctx.recv_message()))
+ guard.set_result(Ok(BatchResult::new(
+ self.ctx.recv_message(),
+ Some(self.ctx.take_initial_metadata()),
+ Some(self.ctx.take_trailing_metadata()),
+ )))
} else {
guard.set_result(Err(Error::RpcFailure(status)))
}