diff options
Diffstat (limited to 'src/lib.rs')
-rw-r--r-- | src/lib.rs | 375 |
1 files changed, 264 insertions, 111 deletions
@@ -1,73 +1,84 @@ extern crate proc_macro; mod enum_hack; +mod error; -use proc_macro2::{Delimiter, Group, Ident, Punct, Spacing, Span, TokenStream, TokenTree}; +use crate::error::{Error, Result}; +use proc_macro::{ + token_stream, Delimiter, Group, Ident, Punct, Spacing, Span, TokenStream, TokenTree, +}; use proc_macro_hack::proc_macro_hack; -use quote::{quote, ToTokens}; -use std::iter::FromIterator; -use syn::parse::{Error, Parse, ParseStream, Parser, Result}; -use syn::{parenthesized, parse_macro_input, Lit, LitStr, Token}; +use std::iter::{self, FromIterator, Peekable}; +use std::panic; #[proc_macro] -pub fn item(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - let input = parse_macro_input!(input as PasteInput); - proc_macro::TokenStream::from(input.expanded) +pub fn item(input: TokenStream) -> TokenStream { + expand_paste(input) } #[proc_macro] -pub fn item_with_macros(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - let input = parse_macro_input!(input as PasteInput); - proc_macro::TokenStream::from(enum_hack::wrap(input.expanded)) +pub fn item_with_macros(input: TokenStream) -> TokenStream { + enum_hack::wrap(expand_paste(input)) } #[proc_macro_hack] -pub fn expr(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - let input = parse_macro_input!(input as PasteInput); - let output = input.expanded; - proc_macro::TokenStream::from(quote!({ #output })) +pub fn expr(input: TokenStream) -> TokenStream { + TokenStream::from(TokenTree::Group(Group::new( + Delimiter::Brace, + expand_paste(input), + ))) } #[doc(hidden)] #[proc_macro_derive(EnumHack)] -pub fn enum_hack(input: proc_macro::TokenStream) -> proc_macro::TokenStream { +pub fn enum_hack(input: TokenStream) -> TokenStream { enum_hack::extract(input) } -struct PasteInput { - expanded: TokenStream, -} - -impl Parse for PasteInput { - fn parse(input: ParseStream) -> Result<Self> { - let mut contains_paste = false; - let expanded = parse(input, &mut contains_paste)?; - Ok(PasteInput { expanded }) +fn expand_paste(input: TokenStream) -> TokenStream { + let mut contains_paste = false; + match expand(input, &mut contains_paste) { + Ok(expanded) => expanded, + Err(err) => err.to_compile_error(), } } -fn parse(input: ParseStream, contains_paste: &mut bool) -> Result<TokenStream> { +fn expand(input: TokenStream, contains_paste: &mut bool) -> Result<TokenStream> { let mut expanded = TokenStream::new(); - let (mut prev_colons, mut colons) = (false, false); - while !input.is_empty() { - let save = input.fork(); - match input.parse()? { - TokenTree::Group(group) => { + let (mut prev_colon, mut colon) = (false, false); + let mut prev_none_group = None::<Group>; + let mut tokens = input.into_iter().peekable(); + loop { + let token = tokens.next(); + if let Some(group) = prev_none_group.take() { + if match (&token, tokens.peek()) { + (Some(TokenTree::Punct(fst)), Some(TokenTree::Punct(snd))) => { + fst.as_char() == ':' && snd.as_char() == ':' && fst.spacing() == Spacing::Joint + } + _ => false, + } { + expanded.extend(group.stream()); + *contains_paste = true; + } else { + expanded.extend(iter::once(TokenTree::Group(group))); + } + } + match token { + Some(TokenTree::Group(group)) => { let delimiter = group.delimiter(); let content = group.stream(); let span = group.span(); if delimiter == Delimiter::Bracket && is_paste_operation(&content) { - let segments = parse_bracket_as_segments.parse2(content)?; + let segments = parse_bracket_as_segments(content, span)?; let pasted = paste_segments(span, &segments)?; - pasted.to_tokens(&mut expanded); + expanded.extend(pasted); *contains_paste = true; } else if is_none_delimited_flat_group(delimiter, &content) { - content.to_tokens(&mut expanded); + expanded.extend(content); *contains_paste = true; } else { let mut group_contains_paste = false; - let nested = (|input: ParseStream| parse(input, &mut group_contains_paste)) - .parse2(content)?; + let nested = expand(content, &mut group_contains_paste)?; let group = if group_contains_paste { let mut group = Group::new(delimiter, nested); group.set_span(span); @@ -76,26 +87,34 @@ fn parse(input: ParseStream, contains_paste: &mut bool) -> Result<TokenStream> { } else { group.clone() }; - let in_path = prev_colons || input.peek(Token![::]); - if in_path && delimiter == Delimiter::None { - group.stream().to_tokens(&mut expanded); + if delimiter != Delimiter::None { + expanded.extend(iter::once(TokenTree::Group(group))); + } else if prev_colon { + expanded.extend(group.stream()); *contains_paste = true; } else { - group.to_tokens(&mut expanded); + prev_none_group = Some(group); + } + } + prev_colon = false; + colon = false; + } + Some(other) => { + match &other { + TokenTree::Punct(punct) if punct.as_char() == ':' => { + prev_colon = colon; + colon = punct.spacing() == Spacing::Joint; + } + _ => { + prev_colon = false; + colon = false; } } + expanded.extend(iter::once(other)); } - other => other.to_tokens(&mut expanded), + None => return Ok(expanded), } - prev_colons = colons; - colons = save.peek(Token![::]); } - Ok(expanded) -} - -fn is_paste_operation(input: &TokenStream) -> bool { - let input = input.clone(); - parse_bracket_as_segments.parse2(input).is_ok() } // https://github.com/dtolnay/paste/issues/26 @@ -140,61 +159,173 @@ fn is_none_delimited_flat_group(delimiter: Delimiter, input: &TokenStream) -> bo state == State::Ident || state == State::Literal || state == State::Lifetime } +struct LitStr { + value: String, + span: Span, +} + +struct Colon { + span: Span, +} + enum Segment { String(String), Apostrophe(Span), Env(LitStr), - Modifier(Token![:], Ident), + Modifier(Colon, Ident), } -fn parse_bracket_as_segments(input: ParseStream) -> Result<Vec<Segment>> { - input.parse::<Token![<]>()?; +fn is_paste_operation(input: &TokenStream) -> bool { + let mut tokens = input.clone().into_iter(); - let segments = parse_segments(input)?; + match &tokens.next() { + Some(TokenTree::Punct(punct)) if punct.as_char() == '<' => {} + _ => return false, + } - input.parse::<Token![>]>()?; - if !input.is_empty() { - return Err(input.error("invalid input")); + let mut has_token = false; + loop { + match &tokens.next() { + Some(TokenTree::Punct(punct)) if punct.as_char() == '>' => { + return has_token && tokens.next().is_none(); + } + Some(_) => has_token = true, + None => return false, + } } - Ok(segments) } -fn parse_segments(input: ParseStream) -> Result<Vec<Segment>> { +fn parse_bracket_as_segments(input: TokenStream, scope: Span) -> Result<Vec<Segment>> { + let mut tokens = input.into_iter().peekable(); + + match &tokens.next() { + Some(TokenTree::Punct(punct)) if punct.as_char() == '<' => {} + Some(wrong) => return Err(Error::new(wrong.span(), "expected `<`")), + None => return Err(Error::new(scope, "expected `[< ... >]`")), + } + + let segments = parse_segments(&mut tokens, scope)?; + + match &tokens.next() { + Some(TokenTree::Punct(punct)) if punct.as_char() == '>' => {} + Some(wrong) => return Err(Error::new(wrong.span(), "expected `>`")), + None => return Err(Error::new(scope, "expected `[< ... >]`")), + } + + match tokens.next() { + Some(unexpected) => Err(Error::new( + unexpected.span(), + "unexpected input, expected `[< ... >]`", + )), + None => Ok(segments), + } +} + +fn parse_segments( + tokens: &mut Peekable<token_stream::IntoIter>, + scope: Span, +) -> Result<Vec<Segment>> { let mut segments = Vec::new(); - while !(input.is_empty() || input.peek(Token![>])) { - match input.parse()? { + while match tokens.peek() { + None => false, + Some(TokenTree::Punct(punct)) => punct.as_char() != '>', + Some(_) => true, + } { + match tokens.next().unwrap() { TokenTree::Ident(ident) => { let mut fragment = ident.to_string(); if fragment.starts_with("r#") { fragment = fragment.split_off(2); } - if fragment == "env" && input.peek(Token![!]) { - input.parse::<Token![!]>()?; - let arg; - parenthesized!(arg in input); - let var: LitStr = arg.parse()?; - segments.push(Segment::Env(var)); + if fragment == "env" + && match tokens.peek() { + Some(TokenTree::Punct(punct)) => punct.as_char() == '!', + _ => false, + } + { + tokens.next().unwrap(); // `!` + let expect_group = tokens.next(); + let parenthesized = match &expect_group { + Some(TokenTree::Group(group)) + if group.delimiter() == Delimiter::Parenthesis => + { + group + } + Some(wrong) => return Err(Error::new(wrong.span(), "expected `(`")), + None => return Err(Error::new(scope, "expected `(` after `env!`")), + }; + let mut inner = parenthesized.stream().into_iter(); + let lit = match inner.next() { + Some(TokenTree::Literal(lit)) => lit, + Some(wrong) => { + return Err(Error::new(wrong.span(), "expected string literal")) + } + None => { + return Err(Error::new2( + ident.span(), + parenthesized.span(), + "expected string literal as argument to env! macro", + )) + } + }; + let lit_string = lit.to_string(); + if lit_string.starts_with('"') + && lit_string.ends_with('"') + && lit_string.len() >= 2 + { + // TODO: maybe handle escape sequences in the string if + // someone has a use case. + segments.push(Segment::Env(LitStr { + value: lit_string[1..lit_string.len() - 1].to_owned(), + span: lit.span(), + })); + } else { + return Err(Error::new(lit.span(), "expected string literal")); + } + if let Some(unexpected) = inner.next() { + return Err(Error::new( + unexpected.span(), + "unexpected token in env! macro", + )); + } } else { segments.push(Segment::String(fragment)); } } TokenTree::Literal(lit) => { - let value = match syn::parse_str(&lit.to_string())? { - Lit::Str(string) => string.value().replace('-', "_"), - Lit::Int(_) => lit.to_string(), - _ => return Err(Error::new(lit.span(), "unsupported literal")), - }; - segments.push(Segment::String(value)); + let mut lit_string = lit.to_string(); + if lit_string.contains(&['#', '\\', '.', '+'][..]) { + return Err(Error::new(lit.span(), "unsupported literal")); + } + lit_string = lit_string + .replace('"', "") + .replace('\'', "") + .replace('-', "_"); + segments.push(Segment::String(lit_string)); } TokenTree::Punct(punct) => match punct.as_char() { - '_' => segments.push(Segment::String("_".to_string())), + '_' => segments.push(Segment::String("_".to_owned())), '\'' => segments.push(Segment::Apostrophe(punct.span())), - ':' => segments.push(Segment::Modifier(Token![:](punct.span()), input.parse()?)), + ':' => { + let colon = Colon { span: punct.span() }; + let ident = match tokens.next() { + Some(TokenTree::Ident(ident)) => ident, + wrong => { + let span = wrong.as_ref().map_or(scope, TokenTree::span); + return Err(Error::new(span, "expected identifier after `:`")); + } + }; + segments.push(Segment::Modifier(colon, ident)); + } _ => return Err(Error::new(punct.span(), "unexpected punct")), }, TokenTree::Group(group) => { if group.delimiter() == Delimiter::None { - let nested = parse_segments.parse2(group.stream())?; + let mut inner = group.stream().into_iter().peekable(); + let nested = parse_segments(&mut inner, group.span())?; + if let Some(unexpected) = inner.next() { + return Err(Error::new(unexpected.span(), "unexpected token")); + } segments.extend(nested); } else { return Err(Error::new(group.span(), "unexpected token")); @@ -221,65 +352,87 @@ fn paste_segments(span: Span, segments: &[Segment]) -> Result<TokenStream> { is_lifetime = true; } Segment::Env(var) => { - let resolved = match std::env::var(var.value()) { + let resolved = match std::env::var(&var.value) { Ok(resolved) => resolved, Err(_) => { - return Err(Error::new(var.span(), "no such env var")); + return Err(Error::new( + var.span, + &format!("no such env var: {:?}", var.value), + )); } }; let resolved = resolved.replace('-', "_"); evaluated.push(resolved); } Segment::Modifier(colon, ident) => { - let span = quote!(#colon #ident); let last = match evaluated.pop() { Some(last) => last, - None => return Err(Error::new_spanned(span, "unexpected modifier")), + None => { + return Err(Error::new2(colon.span, ident.span(), "unexpected modifier")) + } }; - if ident == "lower" { - evaluated.push(last.to_lowercase()); - } else if ident == "upper" { - evaluated.push(last.to_uppercase()); - } else if ident == "snake" { - let mut acc = String::new(); - let mut prev = '_'; - for ch in last.chars() { - if ch.is_uppercase() && prev != '_' { - acc.push('_'); + match ident.to_string().as_str() { + "lower" => { + evaluated.push(last.to_lowercase()); + } + "upper" => { + evaluated.push(last.to_uppercase()); + } + "snake" => { + let mut acc = String::new(); + let mut prev = '_'; + for ch in last.chars() { + if ch.is_uppercase() && prev != '_' { + acc.push('_'); + } + acc.push(ch); + prev = ch; } - acc.push(ch); - prev = ch; + evaluated.push(acc.to_lowercase()); } - evaluated.push(acc.to_lowercase()); - } else if ident == "camel" { - let mut acc = String::new(); - let mut prev = '_'; - for ch in last.chars() { - if ch != '_' { - if prev == '_' { - for chu in ch.to_uppercase() { - acc.push(chu); - } - } else if prev.is_uppercase() { - for chl in ch.to_lowercase() { - acc.push(chl); + "camel" => { + let mut acc = String::new(); + let mut prev = '_'; + for ch in last.chars() { + if ch != '_' { + if prev == '_' { + for chu in ch.to_uppercase() { + acc.push(chu); + } + } else if prev.is_uppercase() { + for chl in ch.to_lowercase() { + acc.push(chl); + } + } else { + acc.push(ch); } - } else { - acc.push(ch); } + prev = ch; } - prev = ch; + evaluated.push(acc); + } + _ => { + return Err(Error::new2( + colon.span, + ident.span(), + "unsupported modifier", + )); } - evaluated.push(acc); - } else { - return Err(Error::new_spanned(span, "unsupported modifier")); } } } } let pasted = evaluated.into_iter().collect::<String>(); - let ident = TokenTree::Ident(Ident::new(&pasted, span)); + let ident = match panic::catch_unwind(|| Ident::new(&pasted, span)) { + Ok(ident) => TokenTree::Ident(ident), + Err(_) => { + return Err(Error::new( + span, + &format!("`{:?}` is not a valid identifier", pasted), + )); + } + }; let tokens = if is_lifetime { let apostrophe = TokenTree::Punct(Punct::new('\'', Spacing::Joint)); vec![apostrophe, ident] |