summaryrefslogtreecommitdiff
path: root/src/visit.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/visit.rs')
-rw-r--r--src/visit.rs187
1 files changed, 187 insertions, 0 deletions
diff --git a/src/visit.rs b/src/visit.rs
new file mode 100644
index 0000000..d63063a
--- /dev/null
+++ b/src/visit.rs
@@ -0,0 +1,187 @@
+use proc_macro2::TokenStream;
+use quote::quote;
+use syn::{
+ visit_mut::{self, visit_item_mut, visit_path_segment_mut, VisitMut},
+ Expr, ExprBlock, File, GenericArgument, GenericParam, Item, PathArguments, PathSegment, Stmt,
+ Type, TypeParamBound, WherePredicate,
+};
+
+pub struct ReplaceGenericType<'a> {
+ generic_type: &'a str,
+ arg_type: &'a PathSegment,
+}
+
+impl<'a> ReplaceGenericType<'a> {
+ pub fn new(generic_type: &'a str, arg_type: &'a PathSegment) -> Self {
+ Self {
+ generic_type,
+ arg_type,
+ }
+ }
+
+ pub fn replace_generic_type(item: &mut Item, generic_type: &'a str, arg_type: &'a PathSegment) {
+ let mut s = Self::new(generic_type, arg_type);
+ s.visit_item_mut(item);
+ }
+}
+
+impl<'a> VisitMut for ReplaceGenericType<'a> {
+ fn visit_item_mut(&mut self, i: &mut Item) {
+ if let Item::Fn(item_fn) = i {
+ // remove generic type from generics <T, F>
+ let args = item_fn
+ .sig
+ .generics
+ .params
+ .iter()
+ .filter_map(|param| {
+ if let GenericParam::Type(type_param) = &param {
+ if type_param.ident.to_string().eq(self.generic_type) {
+ None
+ } else {
+ Some(param)
+ }
+ } else {
+ Some(param)
+ }
+ })
+ .collect::<Vec<_>>();
+ item_fn.sig.generics.params = args.into_iter().cloned().collect();
+
+ // remove generic type from where clause
+ if let Some(where_clause) = &mut item_fn.sig.generics.where_clause {
+ let new_where_clause = where_clause
+ .predicates
+ .iter()
+ .filter_map(|predicate| {
+ if let WherePredicate::Type(predicate_type) = predicate {
+ if let Type::Path(p) = &predicate_type.bounded_ty {
+ if p.path.segments[0].ident.to_string().eq(self.generic_type) {
+ None
+ } else {
+ Some(predicate)
+ }
+ } else {
+ Some(predicate)
+ }
+ } else {
+ Some(predicate)
+ }
+ })
+ .collect::<Vec<_>>();
+
+ where_clause.predicates = new_where_clause.into_iter().cloned().collect();
+ };
+ }
+ visit_item_mut(self, i)
+ }
+ fn visit_path_segment_mut(&mut self, i: &mut PathSegment) {
+ // replace generic type with target type
+ if i.ident.to_string().eq(&self.generic_type) {
+ *i = self.arg_type.clone();
+ }
+ visit_path_segment_mut(self, i);
+ }
+}
+
+pub struct AsyncAwaitRemoval;
+
+impl AsyncAwaitRemoval {
+ pub fn remove_async_await(&mut self, item: TokenStream) -> TokenStream {
+ let mut syntax_tree: File = syn::parse(item.into()).unwrap();
+ self.visit_file_mut(&mut syntax_tree);
+ quote!(#syntax_tree)
+ }
+}
+
+impl VisitMut for AsyncAwaitRemoval {
+ fn visit_expr_mut(&mut self, node: &mut Expr) {
+ // Delegate to the default impl to visit nested expressions.
+ visit_mut::visit_expr_mut(self, node);
+
+ match node {
+ Expr::Await(expr) => *node = (*expr.base).clone(),
+
+ Expr::Async(expr) => {
+ let inner = &expr.block;
+ let sync_expr = if let [Stmt::Expr(expr, None)] = inner.stmts.as_slice() {
+ // remove useless braces when there is only one statement
+ expr.clone()
+ } else {
+ Expr::Block(ExprBlock {
+ attrs: expr.attrs.clone(),
+ block: inner.clone(),
+ label: None,
+ })
+ };
+ *node = sync_expr;
+ }
+ _ => {}
+ }
+ }
+
+ fn visit_item_mut(&mut self, i: &mut Item) {
+ // find generic parameter of Future and replace it with its Output type
+ if let Item::Fn(item_fn) = i {
+ let mut inputs: Vec<(String, PathSegment)> = vec![];
+
+ // generic params: <T:Future<Output=()>, F>
+ for param in &item_fn.sig.generics.params {
+ // generic param: T:Future<Output=()>
+ if let GenericParam::Type(type_param) = param {
+ let generic_type_name = type_param.ident.to_string();
+
+ // bound: Future<Output=()>
+ for bound in &type_param.bounds {
+ inputs.extend(search_trait_bound(&generic_type_name, bound));
+ }
+ }
+ }
+
+ if let Some(where_clause) = &item_fn.sig.generics.where_clause {
+ for predicate in &where_clause.predicates {
+ if let WherePredicate::Type(predicate_type) = predicate {
+ let generic_type_name = if let Type::Path(p) = &predicate_type.bounded_ty {
+ p.path.segments[0].ident.to_string()
+ } else {
+ panic!("Please submit an issue");
+ };
+
+ for bound in &predicate_type.bounds {
+ inputs.extend(search_trait_bound(&generic_type_name, bound));
+ }
+ }
+ }
+ }
+
+ for (generic_type_name, path_seg) in &inputs {
+ ReplaceGenericType::replace_generic_type(i, generic_type_name, path_seg);
+ }
+ }
+ visit_item_mut(self, i);
+ }
+}
+
+fn search_trait_bound(
+ generic_type_name: &str,
+ bound: &TypeParamBound,
+) -> Vec<(String, PathSegment)> {
+ let mut inputs = vec![];
+
+ if let TypeParamBound::Trait(trait_bound) = bound {
+ let segment = &trait_bound.path.segments[trait_bound.path.segments.len() - 1];
+ let name = segment.ident.to_string();
+ if name.eq("Future") {
+ // match Future<Output=Type>
+ if let PathArguments::AngleBracketed(args) = &segment.arguments {
+ // binding: Output=Type
+ if let GenericArgument::AssocType(binding) = &args.args[0] {
+ if let Type::Path(p) = &binding.ty {
+ inputs.push((generic_type_name.to_owned(), p.path.segments[0].clone()));
+ }
+ }
+ }
+ }
+ }
+ inputs
+}