Skip to main content

dfir_lang/
parse.rs

1//! AST for surface syntax, modelled on [`syn`]'s ASTs.
2#![allow(clippy::allow_attributes, missing_docs, reason = "internal use")]
3
4use std::fmt::Debug;
5use std::hash::Hash;
6
7use proc_macro2::{Span, TokenStream};
8use quote::ToTokens;
9use syn::parse::discouraged::Speculative;
10use syn::parse::{Parse, ParseStream, Parser};
11use syn::punctuated::Punctuated;
12use syn::token::{Brace, Bracket, Paren};
13use syn::{
14    AngleBracketedGenericArguments, Expr, ExprPath, GenericArgument, Ident, ItemUse, LitInt, Path,
15    PathArguments, PathSegment, Token, braced, bracketed, parenthesized,
16};
17
18use crate::process_singletons::preprocess_singletons;
19
20pub struct DfirCode {
21    pub statements: Vec<DfirStatement>,
22}
23impl Parse for DfirCode {
24    fn parse(input: ParseStream) -> syn::Result<Self> {
25        let mut statements = Vec::new();
26        while !input.is_empty() {
27            statements.push(input.parse()?);
28        }
29        Ok(DfirCode { statements })
30    }
31}
32impl ToTokens for DfirCode {
33    fn to_tokens(&self, tokens: &mut TokenStream) {
34        for statement in self.statements.iter() {
35            statement.to_tokens(tokens);
36        }
37    }
38}
39
40pub enum DfirStatement {
41    Use(ItemUse),
42    Named(NamedStatement),
43    Pipeline(PipelineStatement),
44    Loop(LoopStatement),
45}
46impl Parse for DfirStatement {
47    fn parse(input: ParseStream) -> syn::Result<Self> {
48        let lookahead1 = input.lookahead1();
49        if lookahead1.peek(Token![use]) {
50            Ok(Self::Use(ItemUse::parse(input)?))
51        } else if lookahead1.peek(Paren) || lookahead1.peek(Bracket) || lookahead1.peek(Token![mod])
52        {
53            Ok(Self::Pipeline(PipelineStatement::parse(input)?))
54        } else if lookahead1.peek(Token![loop]) {
55            Ok(Self::Loop(LoopStatement::parse(input)?))
56        } else if lookahead1.peek(Ident) {
57            let fork = input.fork();
58            let _: Path = fork.parse()?;
59            let lookahead2 = fork.lookahead1();
60            if lookahead2.peek(Token![=]) {
61                Ok(Self::Named(NamedStatement::parse(input)?))
62            } else if lookahead2.peek(Token![->])
63                || lookahead2.peek(Paren)
64                || lookahead2.peek(Bracket)
65            {
66                Ok(Self::Pipeline(PipelineStatement::parse(input)?))
67            } else {
68                Err(lookahead2.error())
69            }
70        } else {
71            Err(lookahead1.error())
72        }
73    }
74}
75impl ToTokens for DfirStatement {
76    fn to_tokens(&self, tokens: &mut TokenStream) {
77        match self {
78            Self::Use(x) => x.to_tokens(tokens),
79            Self::Named(x) => x.to_tokens(tokens),
80            Self::Pipeline(x) => x.to_tokens(tokens),
81            Self::Loop(x) => x.to_tokens(tokens),
82        }
83    }
84}
85
86pub struct NamedStatement {
87    pub name: Ident,
88    pub equals: Token![=],
89    pub pipeline: Pipeline,
90    pub semi_token: Token![;],
91}
92impl Parse for NamedStatement {
93    fn parse(input: ParseStream) -> syn::Result<Self> {
94        let name = input.parse()?;
95        let equals = input.parse()?;
96        let pipeline = input.parse()?;
97        let semi_token = input.parse()?;
98        Ok(Self {
99            name,
100            equals,
101            pipeline,
102            semi_token,
103        })
104    }
105}
106impl ToTokens for NamedStatement {
107    fn to_tokens(&self, tokens: &mut TokenStream) {
108        self.name.to_tokens(tokens);
109        self.equals.to_tokens(tokens);
110        self.pipeline.to_tokens(tokens);
111        self.semi_token.to_tokens(tokens);
112    }
113}
114
115pub struct PipelineStatement {
116    pub pipeline: Pipeline,
117    pub semi_token: Token![;],
118}
119impl Parse for PipelineStatement {
120    fn parse(input: ParseStream) -> syn::Result<Self> {
121        let pipeline = input.parse()?;
122        let semi_token = input.parse()?;
123        Ok(Self {
124            pipeline,
125            semi_token,
126        })
127    }
128}
129impl ToTokens for PipelineStatement {
130    fn to_tokens(&self, tokens: &mut TokenStream) {
131        self.pipeline.to_tokens(tokens);
132        self.semi_token.to_tokens(tokens);
133    }
134}
135
136#[derive(Clone, Debug)]
137pub enum Pipeline {
138    Paren(Ported<PipelineParen>),
139    Name(Ported<Ident>),
140    Link(PipelineLink),
141    Operator(Operator),
142    ModuleBoundary(Ported<Token![mod]>),
143}
144impl Pipeline {
145    fn parse_one(input: ParseStream) -> syn::Result<Self> {
146        let lookahead1 = input.lookahead1();
147
148        // Leading indexing
149        if lookahead1.peek(Bracket) {
150            let inn_idx = input.parse()?;
151            let lookahead2 = input.lookahead1();
152            // Indexed paren
153            if lookahead2.peek(Paren) {
154                Ok(Self::Paren(Ported::parse_rest(Some(inn_idx), input)?))
155            }
156            // Indexed name
157            else if lookahead2.peek(Ident) {
158                Ok(Self::Name(Ported::parse_rest(Some(inn_idx), input)?))
159            }
160            // Indexed module boundary
161            else if lookahead2.peek(Token![mod]) {
162                Ok(Self::ModuleBoundary(Ported::parse_rest(
163                    Some(inn_idx),
164                    input,
165                )?))
166            }
167            // Emit lookahead expected tokens errors.
168            else {
169                Err(lookahead2.error())
170            }
171        // module input/output
172        } else if lookahead1.peek(Token![mod]) {
173            Ok(Self::ModuleBoundary(input.parse()?))
174        // Ident or macro-style expression
175        } else if lookahead1.peek(Ident) {
176            let speculative = input.fork();
177            let _ident: Ident = speculative.parse()?;
178
179            // If has paren or generic next, it's an operator
180            if speculative.peek(Paren)
181                || speculative.peek(Token![<])
182                || speculative.peek(Token![::])
183            {
184                Ok(Self::Operator(input.parse()?))
185            }
186            // Otherwise it's a variable name
187            else {
188                Ok(Self::Name(input.parse()?))
189            }
190        }
191        // Paren group
192        else if lookahead1.peek(Paren) {
193            Ok(Self::Paren(input.parse()?))
194        }
195        // Emit lookahead expected tokens errors.
196        else {
197            Err(lookahead1.error())
198        }
199    }
200}
201impl Parse for Pipeline {
202    fn parse(input: ParseStream) -> syn::Result<Self> {
203        let lhs = Pipeline::parse_one(input)?;
204        if input.is_empty() || input.peek(Token![;]) {
205            Ok(lhs)
206        } else {
207            let arrow = input.parse()?;
208            let rhs = input.parse()?;
209            let lhs = Box::new(lhs);
210            Ok(Self::Link(PipelineLink { lhs, arrow, rhs }))
211        }
212    }
213}
214impl ToTokens for Pipeline {
215    fn to_tokens(&self, tokens: &mut TokenStream) {
216        match self {
217            Self::Paren(x) => x.to_tokens(tokens),
218            Self::Link(x) => x.to_tokens(tokens),
219            Self::Name(x) => x.to_tokens(tokens),
220            Self::Operator(x) => x.to_tokens(tokens),
221            Self::ModuleBoundary(x) => x.to_tokens(tokens),
222        }
223    }
224}
225
226pub struct LoopStatement {
227    pub loop_token: Token![loop],
228    pub ident: Option<Ident>,
229    pub brace_token: Brace,
230    pub statements: Vec<DfirStatement>,
231    pub semi_token: Token![;],
232}
233impl Parse for LoopStatement {
234    fn parse(input: ParseStream) -> syn::Result<Self> {
235        let loop_token = input.parse()?;
236        let ident = input.parse()?;
237        let content;
238        let brace_token = braced!(content in input);
239        let mut statements = Vec::new();
240        while !content.is_empty() {
241            statements.push(content.parse()?);
242        }
243        let semi_token = input.parse()?;
244        Ok(Self {
245            loop_token,
246            ident,
247            brace_token,
248            statements,
249            semi_token,
250        })
251    }
252}
253impl ToTokens for LoopStatement {
254    fn to_tokens(&self, tokens: &mut TokenStream) {
255        self.loop_token.to_tokens(tokens);
256        self.ident.to_tokens(tokens);
257        self.brace_token.surround(tokens, |tokens| {
258            for statement in self.statements.iter() {
259                statement.to_tokens(tokens);
260            }
261        });
262        self.semi_token.to_tokens(tokens);
263    }
264}
265
266#[derive(Clone, Debug)]
267pub struct Ported<Inner> {
268    pub inn: Option<Indexing>,
269    pub inner: Inner,
270    pub out: Option<Indexing>,
271}
272impl<Inner> Ported<Inner>
273where
274    Inner: Parse,
275{
276    /// The caller will often parse the first port (`inn`) as part of determining what to parse
277    /// next, so this will do the rest after that.
278    fn parse_rest(inn: Option<Indexing>, input: ParseStream) -> syn::Result<Self> {
279        let inner = input.parse()?;
280        let out = input.call(Indexing::parse_opt)?;
281        Ok(Self { inn, inner, out })
282    }
283}
284impl<Inner> Parse for Ported<Inner>
285where
286    Inner: Parse,
287{
288    fn parse(input: ParseStream) -> syn::Result<Self> {
289        let inn = input.call(Indexing::parse_opt)?;
290        Self::parse_rest(inn, input)
291    }
292}
293impl<Inner> ToTokens for Ported<Inner>
294where
295    Inner: ToTokens,
296{
297    fn to_tokens(&self, tokens: &mut TokenStream) {
298        self.inn.to_tokens(tokens);
299        self.inner.to_tokens(tokens);
300        self.out.to_tokens(tokens);
301    }
302}
303
304#[derive(Clone, Debug)]
305pub struct PipelineParen {
306    pub paren_token: Paren,
307    pub pipeline: Box<Pipeline>,
308}
309impl Parse for PipelineParen {
310    fn parse(input: ParseStream) -> syn::Result<Self> {
311        let content;
312        let paren_token = parenthesized!(content in input);
313        let pipeline = content.parse()?;
314        Ok(Self {
315            paren_token,
316            pipeline,
317        })
318    }
319}
320impl ToTokens for PipelineParen {
321    fn to_tokens(&self, tokens: &mut TokenStream) {
322        self.paren_token.surround(tokens, |tokens| {
323            self.pipeline.to_tokens(tokens);
324        });
325    }
326}
327
328#[derive(Clone, Debug)]
329pub struct PipelineLink {
330    pub lhs: Box<Pipeline>,
331    pub arrow: Token![->],
332    pub rhs: Box<Pipeline>,
333}
334impl Parse for PipelineLink {
335    fn parse(input: ParseStream) -> syn::Result<Self> {
336        let lhs = input.parse()?;
337        let arrow = input.parse()?;
338        let rhs = input.parse()?;
339
340        Ok(Self { lhs, arrow, rhs })
341    }
342}
343impl ToTokens for PipelineLink {
344    fn to_tokens(&self, tokens: &mut TokenStream) {
345        self.lhs.to_tokens(tokens);
346        self.arrow.to_tokens(tokens);
347        self.rhs.to_tokens(tokens);
348    }
349}
350
351#[derive(Clone, Debug)]
352pub struct Indexing {
353    pub bracket_token: Bracket,
354    pub index: PortIndex,
355}
356impl Indexing {
357    fn parse_opt(input: ParseStream) -> syn::Result<Option<Self>> {
358        input.peek(Bracket).then(|| input.parse()).transpose()
359    }
360}
361impl Parse for Indexing {
362    fn parse(input: ParseStream) -> syn::Result<Self> {
363        let content;
364        let bracket_token = bracketed!(content in input);
365        let index = content.parse()?;
366        Ok(Self {
367            bracket_token,
368            index,
369        })
370    }
371}
372impl ToTokens for Indexing {
373    fn to_tokens(&self, tokens: &mut TokenStream) {
374        self.bracket_token.surround(tokens, |tokens| {
375            self.index.to_tokens(tokens);
376        });
377    }
378}
379
380/// Port can either be an int or a name (path).
381#[derive(Clone, Debug)]
382pub enum PortIndex {
383    Int(IndexInt),
384    Path(ExprPath),
385}
386impl Parse for PortIndex {
387    fn parse(input: ParseStream) -> syn::Result<Self> {
388        let lookahead = input.lookahead1();
389        if lookahead.peek(LitInt) {
390            input.parse().map(Self::Int)
391        } else {
392            input.parse().map(Self::Path)
393        }
394    }
395}
396impl ToTokens for PortIndex {
397    fn to_tokens(&self, tokens: &mut TokenStream) {
398        match self {
399            PortIndex::Int(index_int) => index_int.to_tokens(tokens),
400            PortIndex::Path(expr_path) => expr_path.to_tokens(tokens),
401        }
402    }
403}
404
405struct TypeHintRemover;
406impl syn::visit_mut::VisitMut for TypeHintRemover {
407    fn visit_expr_mut(&mut self, expr: &mut Expr) {
408        if let Expr::Call(expr_call) = expr &&
409            let Expr::Path(path) = expr_call.func.as_ref() &&
410                // if it is a call of the form `::...::*_type_hint(xyz)`,
411                // typically `::stageleft::...`, replace it with `xyz`
412                path
413                    .path
414                    .segments
415                    .last()
416                    .unwrap()
417                    .ident
418                    .to_string()
419                    .ends_with("_type_hint")
420        {
421            *expr = expr_call.args.first().unwrap().clone();
422        }
423
424        syn::visit_mut::visit_expr_mut(self, expr);
425    }
426}
427
428#[derive(Clone)]
429pub struct Operator {
430    pub path: Path,
431    pub paren_token: Paren,
432    pub args_raw: TokenStream,
433    pub args: Punctuated<Expr, Token![,]>,
434    pub singletons_referenced: Vec<SingletonRef>,
435}
436
437impl Operator {
438    pub fn name(&self) -> Path {
439        Path {
440            leading_colon: self.path.leading_colon,
441            segments: self
442                .path
443                .segments
444                .iter()
445                .map(|seg| PathSegment {
446                    ident: seg.ident.clone(),
447                    arguments: PathArguments::None,
448                })
449                .collect(),
450        }
451    }
452
453    pub fn name_string(&self) -> String {
454        self.name().to_token_stream().to_string()
455    }
456
457    pub fn type_arguments(&self) -> Option<&Punctuated<GenericArgument, Token![,]>> {
458        let end = self.path.segments.last()?;
459        if let PathArguments::AngleBracketed(type_args) = &end.arguments {
460            Some(&type_args.args)
461        } else {
462            None
463        }
464    }
465
466    pub fn args(&self) -> &Punctuated<Expr, Token![,]> {
467        &self.args
468    }
469
470    /// Output the operator as a formatted string using `prettyplease`.
471    pub fn to_pretty_string(&self) -> String {
472        // TODO(mingwei): preserve #args_raw instead of just args?
473        let mut file: syn::File = syn::parse_quote! {
474            fn main() {
475                #self
476            }
477        };
478
479        syn::visit_mut::visit_file_mut(&mut TypeHintRemover, &mut file);
480        let str = prettyplease::unparse(&file);
481        str.trim_start()
482            .trim_start_matches("fn main()")
483            .trim_start()
484            .trim_start_matches('{')
485            .trim_start()
486            .trim_end()
487            .trim_end_matches('}')
488            .trim_end()
489            .replace("\n    ", "\n") // Remove extra leading indent
490    }
491}
492impl Parse for Operator {
493    fn parse(input: ParseStream) -> syn::Result<Self> {
494        let path: Path = input.parse()?;
495        if let Some(path_seg) = path.segments.iter().find(|path_seg| {
496            matches!(
497                &path_seg.arguments,
498                PathArguments::AngleBracketed(AngleBracketedGenericArguments {
499                    colon2_token: None,
500                    ..
501                })
502            )
503        }) {
504            return Err(syn::Error::new_spanned(
505                path_seg,
506                "Missing `::` before `<...>` generic arguments",
507            ));
508        }
509
510        let content;
511        let paren_token = parenthesized!(content in input);
512        let args_raw: TokenStream = content.parse()?;
513        let mut singletons_referenced: Vec<SingletonRef> = Vec::new();
514        let args = Punctuated::parse_terminated.parse2(preprocess_singletons(
515            args_raw.clone(),
516            &mut singletons_referenced,
517        ))?;
518
519        Ok(Self {
520            path,
521            paren_token,
522            args_raw,
523            args,
524            singletons_referenced,
525        })
526    }
527}
528
529impl ToTokens for Operator {
530    fn to_tokens(&self, tokens: &mut TokenStream) {
531        self.path.to_tokens(tokens);
532        self.paren_token.surround(tokens, |tokens| {
533            self.args.to_tokens(tokens);
534        });
535    }
536}
537
538impl Debug for Operator {
539    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
540        f.debug_struct("Operator")
541            .field("path", &self.path.to_token_stream().to_string())
542            .field(
543                "args",
544                &self
545                    .args
546                    .iter()
547                    .map(|a| a.to_token_stream().to_string())
548                    .collect::<Vec<_>>(),
549            )
550            .finish()
551    }
552}
553
554#[derive(Clone, Copy, Debug)]
555pub struct IndexInt {
556    pub value: isize,
557    pub span: Span,
558}
559impl Parse for IndexInt {
560    fn parse(input: ParseStream) -> syn::Result<Self> {
561        let lit_int: LitInt = input.parse()?;
562        let value = lit_int.base10_parse()?;
563        Ok(Self {
564            value,
565            span: lit_int.span(),
566        })
567    }
568}
569impl ToTokens for IndexInt {
570    fn to_tokens(&self, tokens: &mut TokenStream) {
571        let lit_int = LitInt::new(&self.value.to_string(), self.span);
572        lit_int.to_tokens(tokens)
573    }
574}
575impl Hash for IndexInt {
576    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
577        self.value.hash(state);
578    }
579}
580impl PartialEq for IndexInt {
581    fn eq(&self, other: &Self) -> bool {
582        self.value == other.value
583    }
584}
585impl Eq for IndexInt {}
586impl PartialOrd for IndexInt {
587    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
588        Some(self.cmp(other))
589    }
590}
591impl Ord for IndexInt {
592    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
593        self.value.cmp(&other.value)
594    }
595}
596
597/// A parsed singleton reference token with mutability and optional access group.
598///
599/// Syntax: `#var`, `#mut var`, `#{N} var`, `#{N} mut var`
600#[derive(Clone, Debug)]
601pub struct SingletonRef {
602    /// Hash `#` marking the start of the singleton.
603    pub hash: Token![#],
604    /// Optional access group for ordering (`#{N}` prefix). Stores the brace group and parsed integer.
605    pub access_group: Option<(Brace, LitInt)>,
606    /// Whether this is a mutable reference (`#mut var` or `#{N} mut var`).
607    pub token_mut: Option<Token![mut]>,
608    /// The variable name being referenced.
609    pub ident: Ident,
610}
611
612impl SingletonRef {
613    /// Returns a parsed singleton reference token (if valid) and all remaining tokens.
614    pub fn try_parse(input: ParseStream) -> syn::Result<(Option<Self>, TokenStream)> {
615        let this = if input.peek(Token![#]) {
616            let fork = input.fork();
617            if let Ok(this) = fork.parse() {
618                input.advance_to(&fork);
619                Some(this)
620            } else {
621                None
622            }
623        } else {
624            None
625        };
626        let tokens = input.parse().expect("infallible");
627        Ok((this, tokens))
628    }
629}
630
631impl Parse for SingletonRef {
632    fn parse(input: ParseStream) -> syn::Result<Self> {
633        let hash = input.parse()?;
634        let access_group = input
635            .peek(Brace)
636            .then(|| {
637                let inner;
638                let brace = braced!(inner in input);
639                let lit_int = inner.parse()?;
640                if !inner.is_empty() {
641                    return Err(inner.error("expected only an integer"));
642                }
643                Ok((brace, lit_int))
644            })
645            .transpose()?;
646        let token_mut = input.parse()?;
647        let ident = input.parse()?;
648        Ok(Self {
649            hash,
650            token_mut,
651            access_group,
652            ident,
653        })
654    }
655}
656
657impl ToTokens for SingletonRef {
658    fn to_tokens(&self, tokens: &mut TokenStream) {
659        self.hash.to_tokens(tokens);
660        if let Some((brace, lit_int)) = self.access_group.as_ref() {
661            brace.surround(tokens, |tokens| {
662                lit_int.to_tokens(tokens);
663            });
664        }
665        self.token_mut.to_tokens(tokens);
666        self.ident.to_tokens(tokens);
667    }
668}
669
#[cfg(test)]
mod test {
    use syn::parse_quote;

    use super::*;

    /// `to_pretty_string` prints only the operator: the `fn main` wrapper is removed and the
    /// result starts at column zero.
    #[test]
    fn test_operator_to_pretty_string() {
        let operator: Operator = parse_quote! {
            demux(|(msg, addr), var_args!(clients, msgs, errs)|
                match msg {
                    Message::ConnectRequest => clients.give(addr),
                    Message::ChatMsg {..} => msgs.give(msg),
                    _ => errs.give(msg),
                }
            )
        };
        let expected = r"
demux(|(msg, addr), var_args!(clients, msgs, errs)| match msg {
    Message::ConnectRequest => clients.give(addr),
    Message::ChatMsg { .. } => msgs.give(msg),
    _ => errs.give(msg),
})
"
        .trim();
        let actual = operator.to_pretty_string();
        assert_eq!(expected, actual);
    }
}