diff options
Diffstat (limited to 'src/visit.rs')
-rw-r--r-- | src/visit.rs | 187 |
1 files changed, 187 insertions, 0 deletions
diff --git a/src/visit.rs b/src/visit.rs new file mode 100644 index 0000000..d63063a --- /dev/null +++ b/src/visit.rs @@ -0,0 +1,187 @@ +use proc_macro2::TokenStream; +use quote::quote; +use syn::{ + visit_mut::{self, visit_item_mut, visit_path_segment_mut, VisitMut}, + Expr, ExprBlock, File, GenericArgument, GenericParam, Item, PathArguments, PathSegment, Stmt, + Type, TypeParamBound, WherePredicate, +}; + +pub struct ReplaceGenericType<'a> { + generic_type: &'a str, + arg_type: &'a PathSegment, +} + +impl<'a> ReplaceGenericType<'a> { + pub fn new(generic_type: &'a str, arg_type: &'a PathSegment) -> Self { + Self { + generic_type, + arg_type, + } + } + + pub fn replace_generic_type(item: &mut Item, generic_type: &'a str, arg_type: &'a PathSegment) { + let mut s = Self::new(generic_type, arg_type); + s.visit_item_mut(item); + } +} + +impl<'a> VisitMut for ReplaceGenericType<'a> { + fn visit_item_mut(&mut self, i: &mut Item) { + if let Item::Fn(item_fn) = i { + // remove generic type from generics <T, F> + let args = item_fn + .sig + .generics + .params + .iter() + .filter_map(|param| { + if let GenericParam::Type(type_param) = ¶m { + if type_param.ident.to_string().eq(self.generic_type) { + None + } else { + Some(param) + } + } else { + Some(param) + } + }) + .collect::<Vec<_>>(); + item_fn.sig.generics.params = args.into_iter().cloned().collect(); + + // remove generic type from where clause + if let Some(where_clause) = &mut item_fn.sig.generics.where_clause { + let new_where_clause = where_clause + .predicates + .iter() + .filter_map(|predicate| { + if let WherePredicate::Type(predicate_type) = predicate { + if let Type::Path(p) = &predicate_type.bounded_ty { + if p.path.segments[0].ident.to_string().eq(self.generic_type) { + None + } else { + Some(predicate) + } + } else { + Some(predicate) + } + } else { + Some(predicate) + } + }) + .collect::<Vec<_>>(); + + where_clause.predicates = new_where_clause.into_iter().cloned().collect(); + }; + } + visit_item_mut(self, i) + } + fn visit_path_segment_mut(&mut self, i: &mut PathSegment) { + // replace generic type with target type + if i.ident.to_string().eq(&self.generic_type) { + *i = self.arg_type.clone(); + } + visit_path_segment_mut(self, i); + } +} + +pub struct AsyncAwaitRemoval; + +impl AsyncAwaitRemoval { + pub fn remove_async_await(&mut self, item: TokenStream) -> TokenStream { + let mut syntax_tree: File = syn::parse(item.into()).unwrap(); + self.visit_file_mut(&mut syntax_tree); + quote!(#syntax_tree) + } +} + +impl VisitMut for AsyncAwaitRemoval { + fn visit_expr_mut(&mut self, node: &mut Expr) { + // Delegate to the default impl to visit nested expressions. + visit_mut::visit_expr_mut(self, node); + + match node { + Expr::Await(expr) => *node = (*expr.base).clone(), + + Expr::Async(expr) => { + let inner = &expr.block; + let sync_expr = if let [Stmt::Expr(expr, None)] = inner.stmts.as_slice() { + // remove useless braces when there is only one statement + expr.clone() + } else { + Expr::Block(ExprBlock { + attrs: expr.attrs.clone(), + block: inner.clone(), + label: None, + }) + }; + *node = sync_expr; + } + _ => {} + } + } + + fn visit_item_mut(&mut self, i: &mut Item) { + // find generic parameter of Future and replace it with its Output type + if let Item::Fn(item_fn) = i { + let mut inputs: Vec<(String, PathSegment)> = vec![]; + + // generic params: <T:Future<Output=()>, F> + for param in &item_fn.sig.generics.params { + // generic param: T:Future<Output=()> + if let GenericParam::Type(type_param) = param { + let generic_type_name = type_param.ident.to_string(); + + // bound: Future<Output=()> + for bound in &type_param.bounds { + inputs.extend(search_trait_bound(&generic_type_name, bound)); + } + } + } + + if let Some(where_clause) = &item_fn.sig.generics.where_clause { + for predicate in &where_clause.predicates { + if let WherePredicate::Type(predicate_type) = predicate { + let generic_type_name = if let Type::Path(p) = &predicate_type.bounded_ty { + p.path.segments[0].ident.to_string() + } else { + panic!("Please submit an issue"); + }; + + for bound in &predicate_type.bounds { + inputs.extend(search_trait_bound(&generic_type_name, bound)); + } + } + } + } + + for (generic_type_name, path_seg) in &inputs { + ReplaceGenericType::replace_generic_type(i, generic_type_name, path_seg); + } + } + visit_item_mut(self, i); + } +} + +fn search_trait_bound( + generic_type_name: &str, + bound: &TypeParamBound, +) -> Vec<(String, PathSegment)> { + let mut inputs = vec![]; + + if let TypeParamBound::Trait(trait_bound) = bound { + let segment = &trait_bound.path.segments[trait_bound.path.segments.len() - 1]; + let name = segment.ident.to_string(); + if name.eq("Future") { + // match Future<Output=Type> + if let PathArguments::AngleBracketed(args) = &segment.arguments { + // binding: Output=Type + if let GenericArgument::AssocType(binding) = &args.args[0] { + if let Type::Path(p) = &binding.ty { + inputs.push((generic_type_name.to_owned(), p.path.segments[0].clone())); + } + } + } + } + } + inputs +} |