//! Generate Rust code from a series of Sequences. use crate::sema::{ExternalSig, ReturnKind, Sym, Term, TermEnv, TermId, Type, TypeEnv, TypeId}; use crate::serialize::{Block, ControlFlow, EvalStep, MatchArm}; use crate::trie_again::{Binding, BindingId, Constraint, RuleSet}; use crate::StableSet; use std::fmt::Write; /// Options for code generation. #[derive(Clone, Debug, Default)] pub struct CodegenOptions { /// Do not include the `#![allow(...)]` pragmas in the generated /// source. Useful if it must be include!()'d elsewhere. pub exclude_global_allow_pragmas: bool, } /// Emit Rust source code for the given type and term environments. pub fn codegen( typeenv: &TypeEnv, termenv: &TermEnv, terms: &[(TermId, RuleSet)], options: &CodegenOptions, ) -> String { Codegen::compile(typeenv, termenv, terms).generate_rust(options) } #[derive(Clone, Debug)] struct Codegen<'a> { typeenv: &'a TypeEnv, termenv: &'a TermEnv, terms: &'a [(TermId, RuleSet)], } struct BodyContext<'a, W> { out: &'a mut W, ruleset: &'a RuleSet, indent: String, is_ref: StableSet, is_bound: StableSet, } impl<'a, W: Write> BodyContext<'a, W> { fn new(out: &'a mut W, ruleset: &'a RuleSet) -> Self { Self { out, ruleset, indent: Default::default(), is_ref: Default::default(), is_bound: Default::default(), } } fn enter_scope(&mut self) -> StableSet { let new = self.is_bound.clone(); std::mem::replace(&mut self.is_bound, new) } fn begin_block(&mut self) -> std::fmt::Result { self.indent.push_str(" "); writeln!(self.out, " {{") } fn end_block(&mut self, scope: StableSet) -> std::fmt::Result { self.is_bound = scope; self.end_block_without_newline()?; writeln!(self.out) } fn end_block_without_newline(&mut self) -> std::fmt::Result { self.indent.truncate(self.indent.len() - 4); write!(self.out, "{}}}", &self.indent) } fn set_ref(&mut self, binding: BindingId, is_ref: bool) { if is_ref { self.is_ref.insert(binding); } else { debug_assert!(!self.is_ref.contains(&binding)); } } } impl<'a> Codegen<'a> { fn compile( typeenv: &'a TypeEnv, termenv: &'a TermEnv, terms: &'a [(TermId, RuleSet)], ) -> Codegen<'a> { Codegen { typeenv, termenv, terms, } } fn generate_rust(&self, options: &CodegenOptions) -> String { let mut code = String::new(); self.generate_header(&mut code, options); self.generate_ctx_trait(&mut code); self.generate_internal_types(&mut code); self.generate_internal_term_constructors(&mut code).unwrap(); code } fn generate_header(&self, code: &mut String, options: &CodegenOptions) { writeln!(code, "// GENERATED BY ISLE. DO NOT EDIT!").unwrap(); writeln!(code, "//").unwrap(); writeln!( code, "// Generated automatically from the instruction-selection DSL code in:", ) .unwrap(); for file in &self.typeenv.filenames { writeln!(code, "// - {}", file).unwrap(); } if !options.exclude_global_allow_pragmas { writeln!( code, "\n#![allow(dead_code, unreachable_code, unreachable_patterns)]" ) .unwrap(); writeln!( code, "#![allow(unused_imports, unused_variables, non_snake_case, unused_mut)]" ) .unwrap(); writeln!( code, "#![allow(irrefutable_let_patterns, unused_assignments, non_camel_case_types)]" ) .unwrap(); } writeln!(code, "\nuse super::*; // Pulls in all external types.").unwrap(); writeln!(code, "use std::marker::PhantomData;").unwrap(); } fn generate_trait_sig(&self, code: &mut String, indent: &str, sig: &ExternalSig) { let ret_tuple = format!( "{open_paren}{rets}{close_paren}", open_paren = if sig.ret_tys.len() != 1 { "(" } else { "" }, rets = sig .ret_tys .iter() .map(|&ty| self.type_name(ty, /* by_ref = */ false)) .collect::>() .join(", "), close_paren = if sig.ret_tys.len() != 1 { ")" } else { "" }, ); if sig.ret_kind == ReturnKind::Iterator { writeln!( code, "{indent}type {name}_iter: ContextIter;", indent = indent, name = sig.func_name, output = ret_tuple, ) .unwrap(); } let ret_ty = match sig.ret_kind { ReturnKind::Plain => ret_tuple, ReturnKind::Option => format!("Option<{}>", ret_tuple), ReturnKind::Iterator => format!("Self::{}_iter", sig.func_name), }; writeln!( code, "{indent}fn {name}(&mut self, {params}) -> {ret_ty};", indent = indent, name = sig.func_name, params = sig .param_tys .iter() .enumerate() .map(|(i, &ty)| format!("arg{}: {}", i, self.type_name(ty, /* by_ref = */ true))) .collect::>() .join(", "), ret_ty = ret_ty, ) .unwrap(); } fn generate_ctx_trait(&self, code: &mut String) { writeln!(code).unwrap(); writeln!( code, "/// Context during lowering: an implementation of this trait" ) .unwrap(); writeln!( code, "/// must be provided with all external constructors and extractors." ) .unwrap(); writeln!( code, "/// A mutable borrow is passed along through all lowering logic." ) .unwrap(); writeln!(code, "pub trait Context {{").unwrap(); for term in &self.termenv.terms { if term.has_external_extractor() { let ext_sig = term.extractor_sig(self.typeenv).unwrap(); self.generate_trait_sig(code, " ", &ext_sig); } if term.has_external_constructor() { let ext_sig = term.constructor_sig(self.typeenv).unwrap(); self.generate_trait_sig(code, " ", &ext_sig); } } writeln!(code, "}}").unwrap(); writeln!( code, r#" pub trait ContextIter {{ type Context; type Output; fn next(&mut self, ctx: &mut Self::Context) -> Option; }} pub struct ContextIterWrapper, C: Context> {{ iter: I, _ctx: PhantomData, }} impl, C: Context> From for ContextIterWrapper {{ fn from(iter: I) -> Self {{ Self {{ iter, _ctx: PhantomData }} }} }} impl, C: Context> ContextIter for ContextIterWrapper {{ type Context = C; type Output = Item; fn next(&mut self, _ctx: &mut Self::Context) -> Option {{ self.iter.next() }} }} "#, ) .unwrap(); } fn generate_internal_types(&self, code: &mut String) { for ty in &self.typeenv.types { match ty { &Type::Enum { name, is_extern, is_nodebug, ref variants, pos, .. } if !is_extern => { let name = &self.typeenv.syms[name.index()]; writeln!( code, "\n/// Internal type {}: defined at {}.", name, pos.pretty_print_line(&self.typeenv.filenames[..]) ) .unwrap(); // Generate the `derive`s. let debug_derive = if is_nodebug { "" } else { ", Debug" }; if variants.iter().all(|v| v.fields.is_empty()) { writeln!( code, "#[derive(Copy, Clone, PartialEq, Eq{})]", debug_derive ) .unwrap(); } else { writeln!(code, "#[derive(Clone{})]", debug_derive).unwrap(); } writeln!(code, "pub enum {} {{", name).unwrap(); for variant in variants { let name = &self.typeenv.syms[variant.name.index()]; if variant.fields.is_empty() { writeln!(code, " {},", name).unwrap(); } else { writeln!(code, " {} {{", name).unwrap(); for field in &variant.fields { let name = &self.typeenv.syms[field.name.index()]; let ty_name = self.typeenv.types[field.ty.index()].name(self.typeenv); writeln!(code, " {}: {},", name, ty_name).unwrap(); } writeln!(code, " }},").unwrap(); } } writeln!(code, "}}").unwrap(); } _ => {} } } } fn type_name(&self, typeid: TypeId, by_ref: bool) -> String { match self.typeenv.types[typeid.index()] { Type::Primitive(_, sym, _) => self.typeenv.syms[sym.index()].clone(), Type::Enum { name, .. } => { let r = if by_ref { "&" } else { "" }; format!("{}{}", r, self.typeenv.syms[name.index()]) } } } fn generate_internal_term_constructors(&self, code: &mut String) -> std::fmt::Result { for &(termid, ref ruleset) in self.terms.iter() { let root = crate::serialize::serialize(ruleset); let mut ctx = BodyContext::new(code, ruleset); let termdata = &self.termenv.terms[termid.index()]; let term_name = &self.typeenv.syms[termdata.name.index()]; writeln!(ctx.out)?; writeln!( ctx.out, "{}// Generated as internal constructor for term {}.", &ctx.indent, term_name, )?; let sig = termdata.constructor_sig(self.typeenv).unwrap(); writeln!( ctx.out, "{}pub fn {}(", &ctx.indent, sig.func_name )?; writeln!(ctx.out, "{} ctx: &mut C,", &ctx.indent)?; for (i, &ty) in sig.param_tys.iter().enumerate() { let (is_ref, sym) = self.ty(ty); write!(ctx.out, "{} arg{}: ", &ctx.indent, i)?; write!( ctx.out, "{}{}", if is_ref { "&" } else { "" }, &self.typeenv.syms[sym.index()] )?; if let Some(binding) = ctx.ruleset.find_binding(&Binding::Argument { index: i.try_into().unwrap(), }) { ctx.set_ref(binding, is_ref); } writeln!(ctx.out, ",")?; } write!(ctx.out, "{}) -> ", &ctx.indent)?; let (_, ret) = self.ty(sig.ret_tys[0]); let ret = &self.typeenv.syms[ret.index()]; match sig.ret_kind { ReturnKind::Iterator => { write!(ctx.out, "impl ContextIter", ret)? } ReturnKind::Option => write!(ctx.out, "Option<{}>", ret)?, ReturnKind::Plain => write!(ctx.out, "{}", ret)?, }; let scope = ctx.enter_scope(); ctx.begin_block()?; if sig.ret_kind == ReturnKind::Iterator { writeln!( ctx.out, "{}let mut returns = ConstructorVec::new();", &ctx.indent )?; } self.emit_block(&mut ctx, &root, sig.ret_kind)?; match (sig.ret_kind, root.steps.last()) { (ReturnKind::Iterator, _) => { writeln!( ctx.out, "{}return ContextIterWrapper::from(returns.into_iter());", &ctx.indent )?; } (_, Some(EvalStep { check: ControlFlow::Return { .. }, .. })) => { // If there's an outermost fallback, no need for another `return` statement. } (ReturnKind::Option, _) => { writeln!(ctx.out, "{}None", &ctx.indent)? } (ReturnKind::Plain, _) => { writeln!(ctx.out, "unreachable!(\"no rule matched for term {{}} at {{}}; should it be partial?\", {:?}, {:?})", term_name, termdata .decl_pos .pretty_print_line(&self.typeenv.filenames[..]) )? } } ctx.end_block(scope)?; } Ok(()) } fn ty(&self, typeid: TypeId) -> (bool, Sym) { match &self.typeenv.types[typeid.index()] { &Type::Primitive(_, sym, _) => (false, sym), &Type::Enum { name, .. } => (true, name), } } fn emit_block( &self, ctx: &mut BodyContext, block: &Block, ret_kind: ReturnKind, ) -> std::fmt::Result { if !matches!(ret_kind, ReturnKind::Iterator) { // Loops are only allowed if we're returning an iterator. assert!(!block .steps .iter() .any(|c| matches!(c.check, ControlFlow::Loop { .. }))); // Unless we're returning an iterator, a case which returns a result must be the last // case in a block. if let Some(result_pos) = block .steps .iter() .position(|c| matches!(c.check, ControlFlow::Return { .. })) { assert_eq!(block.steps.len() - 1, result_pos); } } for case in block.steps.iter() { for &expr in case.bind_order.iter() { write!(ctx.out, "{}let v{} = ", &ctx.indent, expr.index())?; self.emit_expr(ctx, expr)?; writeln!(ctx.out, ";")?; ctx.is_bound.insert(expr); } match &case.check { // Use a shorthand notation if there's only one match arm. ControlFlow::Match { source, arms } if arms.len() == 1 => { let arm = &arms[0]; let scope = ctx.enter_scope(); match arm.constraint { Constraint::ConstInt { .. } | Constraint::ConstPrim { .. } => { write!(ctx.out, "{}if ", &ctx.indent)?; self.emit_expr(ctx, *source)?; write!(ctx.out, " == ")?; self.emit_constraint(ctx, *source, arm)?; } Constraint::Variant { .. } | Constraint::Some => { write!(ctx.out, "{}if let ", &ctx.indent)?; self.emit_constraint(ctx, *source, arm)?; write!(ctx.out, " = ")?; self.emit_source(ctx, *source, arm.constraint)?; } } ctx.begin_block()?; self.emit_block(ctx, &arm.body, ret_kind)?; ctx.end_block(scope)?; } ControlFlow::Match { source, arms } => { let scope = ctx.enter_scope(); write!(ctx.out, "{}match ", &ctx.indent)?; self.emit_source(ctx, *source, arms[0].constraint)?; ctx.begin_block()?; for arm in arms.iter() { let scope = ctx.enter_scope(); write!(ctx.out, "{}", &ctx.indent)?; self.emit_constraint(ctx, *source, arm)?; write!(ctx.out, " =>")?; ctx.begin_block()?; self.emit_block(ctx, &arm.body, ret_kind)?; ctx.end_block(scope)?; } // Always add a catchall, because we don't do exhaustiveness checking on the // match arms. writeln!(ctx.out, "{}_ => {{}}", &ctx.indent)?; ctx.end_block(scope)?; } ControlFlow::Equal { a, b, body } => { let scope = ctx.enter_scope(); write!(ctx.out, "{}if ", &ctx.indent)?; self.emit_expr(ctx, *a)?; write!(ctx.out, " == ")?; self.emit_expr(ctx, *b)?; ctx.begin_block()?; self.emit_block(ctx, body, ret_kind)?; ctx.end_block(scope)?; } ControlFlow::Loop { result, body } => { let source = match &ctx.ruleset.bindings[result.index()] { Binding::Iterator { source } => source, _ => unreachable!("Loop from a non-Iterator"), }; let scope = ctx.enter_scope(); write!(ctx.out, "{}let mut v{} = ", &ctx.indent, source.index())?; self.emit_expr(ctx, *source)?; writeln!(ctx.out, ";")?; write!( ctx.out, "{}while let Some(v{}) = v{}.next(ctx)", &ctx.indent, result.index(), source.index() )?; ctx.is_bound.insert(*result); ctx.begin_block()?; self.emit_block(ctx, body, ret_kind)?; ctx.end_block(scope)?; } &ControlFlow::Return { pos, result } => { writeln!( ctx.out, "{}// Rule at {}.", &ctx.indent, pos.pretty_print_line(&self.typeenv.filenames) )?; write!(ctx.out, "{}", &ctx.indent)?; match ret_kind { ReturnKind::Plain => write!(ctx.out, "return ")?, ReturnKind::Option => write!(ctx.out, "return Some(")?, ReturnKind::Iterator => write!(ctx.out, "returns.push(")?, } self.emit_expr(ctx, result)?; if ctx.is_ref.contains(&result) { write!(ctx.out, ".clone()")?; } match ret_kind { ReturnKind::Plain => writeln!(ctx.out, ";")?, ReturnKind::Option | ReturnKind::Iterator => writeln!(ctx.out, ");")?, } } } } Ok(()) } fn emit_expr(&self, ctx: &mut BodyContext, result: BindingId) -> std::fmt::Result { if ctx.is_bound.contains(&result) { return write!(ctx.out, "v{}", result.index()); } let binding = &ctx.ruleset.bindings[result.index()]; let mut call = |term: TermId, parameters: &[BindingId], get_sig: fn(&Term, &TypeEnv) -> Option| { let termdata = &self.termenv.terms[term.index()]; let sig = get_sig(termdata, self.typeenv).unwrap(); if let &[ret_ty] = &sig.ret_tys[..] { let (is_ref, _) = self.ty(ret_ty); if is_ref { ctx.set_ref(result, true); write!(ctx.out, "&")?; } } write!(ctx.out, "{}(ctx", sig.full_name)?; debug_assert_eq!(parameters.len(), sig.param_tys.len()); for (¶meter, &arg_ty) in parameters.iter().zip(sig.param_tys.iter()) { let (is_ref, _) = self.ty(arg_ty); write!(ctx.out, ", ")?; let (before, after) = match (is_ref, ctx.is_ref.contains(¶meter)) { (false, true) => ("", ".clone()"), (true, false) => ("&", ""), _ => ("", ""), }; write!(ctx.out, "{}", before)?; self.emit_expr(ctx, parameter)?; write!(ctx.out, "{}", after)?; } write!(ctx.out, ")") }; match binding { &Binding::ConstInt { val, ty } => self.emit_int(ctx, val, ty), Binding::ConstPrim { val } => write!(ctx.out, "{}", &self.typeenv.syms[val.index()]), Binding::Argument { index } => write!(ctx.out, "arg{}", index.index()), Binding::Extractor { term, parameter } => { call(*term, std::slice::from_ref(parameter), Term::extractor_sig) } Binding::Constructor { term, parameters, .. } => call(*term, ¶meters[..], Term::constructor_sig), Binding::MakeVariant { ty, variant, fields, } => { let (name, variants) = match &self.typeenv.types[ty.index()] { Type::Enum { name, variants, .. } => (name, variants), _ => unreachable!("MakeVariant with primitive type"), }; let variant = &variants[variant.index()]; write!( ctx.out, "{}::{}", &self.typeenv.syms[name.index()], &self.typeenv.syms[variant.name.index()] )?; if !fields.is_empty() { ctx.begin_block()?; for (field, value) in variant.fields.iter().zip(fields.iter()) { write!( ctx.out, "{}{}: ", &ctx.indent, &self.typeenv.syms[field.name.index()], )?; self.emit_expr(ctx, *value)?; if ctx.is_ref.contains(value) { write!(ctx.out, ".clone()")?; } writeln!(ctx.out, ",")?; } ctx.end_block_without_newline()?; } Ok(()) } &Binding::MatchSome { source } => { self.emit_expr(ctx, source)?; write!(ctx.out, "?") } &Binding::MatchTuple { source, field } => { self.emit_expr(ctx, source)?; write!(ctx.out, ".{}", field.index()) } // These are not supposed to happen. If they do, make the generated code fail to compile // so this is easier to debug than if we panic during codegen. &Binding::MatchVariant { source, field, .. } => { self.emit_expr(ctx, source)?; write!(ctx.out, ".{} /*FIXME*/", field.index()) } &Binding::Iterator { source } => { self.emit_expr(ctx, source)?; write!(ctx.out, ".next() /*FIXME*/") } } } fn emit_source( &self, ctx: &mut BodyContext, source: BindingId, constraint: Constraint, ) -> std::fmt::Result { if let Constraint::Variant { .. } = constraint { if !ctx.is_ref.contains(&source) { write!(ctx.out, "&")?; } } self.emit_expr(ctx, source) } fn emit_constraint( &self, ctx: &mut BodyContext, source: BindingId, arm: &MatchArm, ) -> std::fmt::Result { let MatchArm { constraint, bindings, .. } = arm; for binding in bindings.iter() { if let &Some(binding) = binding { ctx.is_bound.insert(binding); } } match *constraint { Constraint::ConstInt { val, ty } => self.emit_int(ctx, val, ty), Constraint::ConstPrim { val } => { write!(ctx.out, "{}", &self.typeenv.syms[val.index()]) } Constraint::Variant { ty, variant, .. } => { let (name, variants) = match &self.typeenv.types[ty.index()] { Type::Enum { name, variants, .. } => (name, variants), _ => unreachable!("Variant constraint on primitive type"), }; let variant = &variants[variant.index()]; write!( ctx.out, "&{}::{}", &self.typeenv.syms[name.index()], &self.typeenv.syms[variant.name.index()] )?; if !bindings.is_empty() { ctx.begin_block()?; let mut skipped_some = false; for (&binding, field) in bindings.iter().zip(variant.fields.iter()) { if let Some(binding) = binding { write!( ctx.out, "{}{}: ", &ctx.indent, &self.typeenv.syms[field.name.index()] )?; let (is_ref, _) = self.ty(field.ty); if is_ref { ctx.set_ref(binding, true); write!(ctx.out, "ref ")?; } writeln!(ctx.out, "v{},", binding.index())?; } else { skipped_some = true; } } if skipped_some { writeln!(ctx.out, "{}..", &ctx.indent)?; } ctx.end_block_without_newline()?; } Ok(()) } Constraint::Some => { write!(ctx.out, "Some(")?; if let Some(binding) = bindings[0] { ctx.set_ref(binding, ctx.is_ref.contains(&source)); write!(ctx.out, "v{}", binding.index())?; } else { write!(ctx.out, "_")?; } write!(ctx.out, ")") } } } fn emit_int( &self, ctx: &mut BodyContext, val: i128, ty: TypeId, ) -> Result<(), std::fmt::Error> { // For the kinds of situations where we use ISLE, magic numbers are // much more likely to be understandable if they're in hex rather than // decimal. // TODO: use better type info (https://github.com/bytecodealliance/wasmtime/issues/5431) if val < 0 && self.typeenv.types[ty.index()] .name(self.typeenv) .starts_with('i') { write!(ctx.out, "-{:#X}", -val) } else { write!(ctx.out, "{:#X}", val) } } }