From f9da7ebe43ee4ae392b98c26f9132ea6b43079e0 Mon Sep 17 00:00:00 2001 From: Nilstrieb <48135649+Nilstrieb@users.noreply.github.com> Date: Fri, 12 Apr 2024 19:58:59 +0200 Subject: [PATCH] Handle shutdown correctly --- Cargo.lock | 1 + Cargo.toml | 1 + build.rs | 1 + proto/controller.proto | 12 ++++++++++++ src/framework.rs | 2 ++ src/main.rs | 28 ++++++++++++++++++++++------ src/server.rs | 34 ++++++++++++++++++++++++++++++---- 7 files changed, 69 insertions(+), 10 deletions(-) create mode 100644 proto/controller.proto diff --git a/Cargo.lock b/Cargo.lock index 8f44471..92bfd2b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1074,6 +1074,7 @@ dependencies = [ "time", "tokio", "tokio-stream", + "tokio-util", "tonic", "tonic-build", "tracing", diff --git a/Cargo.toml b/Cargo.toml index 82090ff..79111b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ tempfile = "3.10.1" time = "0.3.35" tokio = { version = "1.37.0", features = ["full"] } tokio-stream = { version = "0.1.15", features = ["net"] } +tokio-util = "0.7.10" tonic = { version = "0.11.0", features = ["tls"] } tracing = "0.1.40" tracing-subscriber = "0.3.18" diff --git a/build.rs b/build.rs index 947ecc2..850a157 100644 --- a/build.rs +++ b/build.rs @@ -1,4 +1,5 @@ fn main() -> Result<(), Box> { tonic_build::compile_protos("proto/tfplugin6.6.proto")?; + tonic_build::compile_protos("proto/controller.proto")?; Ok(()) } diff --git a/proto/controller.proto b/proto/controller.proto new file mode 100644 index 0000000..8397e70 --- /dev/null +++ b/proto/controller.proto @@ -0,0 +1,12 @@ + +syntax = "proto3"; +package plugin; +option go_package = "./plugin"; + +message Empty { +} + +// The GRPCController is responsible for telling the plugin server to shutdown. +service GRPCController { + rpc Shutdown(Empty) returns (Empty); +} diff --git a/src/framework.rs b/src/framework.rs index 0039c86..32f41e0 100644 --- a/src/framework.rs +++ b/src/framework.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + pub trait DataSource { fn schema(&self); fn read(&self) -> DResult<()>; diff --git a/src/main.rs b/src/main.rs index 0932840..dee01a9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ mod cert; +mod framework; mod server; mod values; -mod framework; use std::{env, path::PathBuf}; @@ -43,15 +43,31 @@ async fn main() -> eyre::Result<()> { let uds = UnixListener::bind(socket).wrap_err("failed to bind unix listener")?; let uds_stream = tokio_stream::wrappers::UnixListenerStream::new(uds); - tonic::transport::Server::builder() + let token = tokio_util::sync::CancellationToken::new(); + + let server = tonic::transport::Server::builder() .tls_config(tls) .wrap_err("invalid TLS config")? .add_service(server::tfplugin6::provider_server::ProviderServer::new( - server::MyProvider, + server::MyProvider { + shutdown: token.clone(), + }, )) - .serve_with_incoming(uds_stream) - .await - .wrap_err("failed to start server")?; + .add_service( + server::plugin::grpc_controller_server::GrpcControllerServer::new( + server::MyController { + shutdown: token.clone(), + }, + ), + ) + .serve_with_incoming(uds_stream); + + tokio::select! { + _ = token.cancelled() => {} + result = server => { + result.wrap_err("failed to start server")?; + } + } Ok(()) } diff --git a/src/server.rs b/src/server.rs index 5267ed3..b23ad7f 100644 --- a/src/server.rs +++ b/src/server.rs @@ -4,19 +4,27 @@ pub mod tfplugin6 { tonic::include_proto!("tfplugin6"); } +pub mod plugin { + tonic::include_proto!("plugin"); +} + use std::{ collections::{BTreeMap, HashMap}, + sync::Mutex, vec, }; use tfplugin6::provider_server::{Provider, ProviderServer}; +use tokio_util::sync::CancellationToken; use tonic::{transport::Server, Request, Response, Result, Status}; use tracing::info; use crate::values::Type; -#[derive(Debug, Default)] -pub struct MyProvider; +#[derive(Debug)] +pub struct MyProvider { + pub shutdown: CancellationToken, +} fn empty_schema() -> tfplugin6::Schema { tfplugin6::Schema { @@ -43,6 +51,7 @@ impl Provider for MyProvider { &self, request: Request, ) -> Result, Status> { + info!("get_metadata"); Err(Status::unimplemented( "GetMetadata: Not implemeneted".to_owned(), )) @@ -250,8 +259,25 @@ impl Provider for MyProvider { &self, request: Request, ) -> Result, Status> { - tracing::error!("stop_provider"); + tracing::info!("stop_provider"); - todo!("stop_provider") + shutdown(&self.shutdown).await + } +} + +pub struct MyController { + pub shutdown: CancellationToken, +} + +async fn shutdown(token: &CancellationToken) -> ! { + token.cancel(); + std::future::poll_fn::<(), _>(|_| std::task::Poll::Pending).await; + unreachable!("we've should have gone to sleep") +} + +#[tonic::async_trait] +impl plugin::grpc_controller_server::GrpcController for MyController { + async fn shutdown(&self, request: Request) -> Result> { + shutdown(&self.shutdown).await } }