use crate::names;

use proc_macro2::{Literal, TokenStream};
use quote::quote;
use witx::Layout;

/// Generates the Rust items for a witx variant type.
///
/// The returned token stream contains:
///
/// * a `pub enum` named after `name` with one case per entry of `v.cases`
///   (a single-field tuple case when the witx case has a payload type, a unit
///   case otherwise), deriving `Clone`, `Debug` and `PartialEq`;
/// * `Display` and `std::error::Error` impls when `derive_std_error` is set;
/// * when no case carries a payload: an additional `Copy` derive,
///   `TryFrom<tag>` / `TryFrom<abi>` conversions into the enum and a
///   `From<enum>` conversion back to the tag type;
/// * a `wiggle::GuestType` impl that reads and writes the tag followed by the
///   payload located `v.payload_offset()` bytes after the start of the value.
pub(super) fn define_variant(
    name: &witx::Id,
    v: &witx::Variant,
    derive_std_error: bool,
) -> TokenStream {
    let ident = names::type_(name);

    // Query the layout once and reuse it for both size and alignment.
    let layout = v.mem_size_align();
    let size = layout.size as u32;
    let align = layout.align;
    let contents_offset = v.payload_offset() as u32;

    let lifetime = quote!('a);
    let tag_ty = super::int_repr_tokens(v.tag_repr);

    // A variant whose cases all lack a payload gets the extra integer
    // conversions and the `Copy` derive below.
    let all_unit_cases = v.cases.iter().all(|c| c.tref.is_none());

    // Enum case declarations.
    let variants = v.cases.iter().map(|c| {
        let var_name = names::enum_variant(&c.name);
        if let Some(tref) = &c.tref {
            let var_type = names::type_ref(tref, lifetime.clone());
            quote!(#var_name(#var_type))
        } else {
            quote!(#var_name)
        }
    });

    // `match` arms for `GuestType::read`, keyed by the case index. The index
    // is emitted unsuffixed so it takes the type of the tag being matched.
    let read_variant = v.cases.iter().enumerate().map(|(i, c)| {
        let i = Literal::usize_unsuffixed(i);
        let variantname = names::enum_variant(&c.name);
        if let Some(tref) = &c.tref {
            let varianttype = names::type_ref(tref, lifetime.clone());
            quote! {
                #i => {
                    // Step over the tag in bytes to reach the payload.
                    let variant_ptr = location.cast::<u8>().add(#contents_offset)?;
                    let variant_val = <#varianttype as wiggle::GuestType>::read(mem, variant_ptr.cast())?;
                    Ok(#ident::#variantname(variant_val))
                }
            }
        } else {
            quote! { #i => Ok(#ident::#variantname), }
        }
    });

    // `match` arms for `GuestType::write`: store the tag, then the payload
    // (if the case has one).
    let write_variant = v.cases.iter().enumerate().map(|(i, c)| {
        let variantname = names::enum_variant(&c.name);
        let write_tag = quote! {
            mem.write(location.cast(), #i as #tag_ty)?;
        };
        if let Some(tref) = &c.tref {
            let varianttype = names::type_ref(tref, lifetime.clone());
            quote! {
                #ident::#variantname(contents) => {
                    #write_tag
                    // Step over the tag in bytes to reach the payload.
                    let variant_ptr = location.cast::<u8>().add(#contents_offset)?;
                    <#varianttype as wiggle::GuestType>::write(mem, variant_ptr.cast(), contents)?;
                }
            }
        } else {
            quote! {
                #ident::#variantname => {
                    #write_tag
                }
            }
        }
    });

    // Integer -> enum conversions, only for payload-free variants.
    let enum_try_from = if all_unit_cases {
        let tryfrom_repr_cases = v.cases.iter().enumerate().map(|(i, c)| {
            let variant_name = names::enum_variant(&c.name);
            let n = Literal::usize_unsuffixed(i);
            quote!(#n => Ok(#ident::#variant_name))
        });
        let abi_ty = names::wasm_type(v.tag_repr.into());
        quote! {
            impl TryFrom<#tag_ty> for #ident {
                type Error = wiggle::GuestError;
                #[inline]
                fn try_from(value: #tag_ty) -> Result<#ident, wiggle::GuestError> {
                    match value {
                        // Comma lives inside the repetition so that an empty
                        // case list does not emit a stray leading comma.
                        #(#tryfrom_repr_cases,)*
                        _ => Err(wiggle::GuestError::InvalidEnumValue(stringify!(#ident))),
                    }
                }
            }

            impl TryFrom<#abi_ty> for #ident {
                type Error = wiggle::GuestError;
                #[inline]
                fn try_from(value: #abi_ty) -> Result<#ident, wiggle::GuestError> {
                    #ident::try_from(#tag_ty::try_from(value)?)
                }
            }
        }
    } else {
        quote!()
    };

    // Enum -> integer conversion, only for payload-free variants.
    let enum_from = if all_unit_cases {
        let from_repr_cases = v.cases.iter().enumerate().map(|(i, c)| {
            let variant_name = names::enum_variant(&c.name);
            let n = Literal::usize_unsuffixed(i);
            quote!(#ident::#variant_name => #n)
        });
        quote! {
            impl From<#ident> for #tag_ty {
                #[inline]
                fn from(v: #ident) -> #tag_ty {
                    match v {
                        #(#from_repr_cases,)*
                    }
                }
            }
        }
    } else {
        quote!()
    };

    // Appended after `Clone, Debug` in the derive list, hence the leading comma.
    let extra_derive = if all_unit_cases {
        quote!(, PartialEq, Copy)
    } else {
        quote!(, PartialEq)
    };

    let error_impls = if derive_std_error {
        quote! {
            impl std::fmt::Display for #ident {
                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                    write!(f, "{:?}", self)
                }
            }
            impl std::error::Error for #ident {}
        }
    } else {
        quote!()
    };

    quote! {
        #[derive(Clone, Debug #extra_derive)]
        pub enum #ident {
            #(#variants),*
        }

        #error_impls

        #enum_try_from
        #enum_from

        impl wiggle::GuestType for #ident {
            #[inline]
            fn guest_size() -> u32 {
                #size
            }

            #[inline]
            fn guest_align() -> usize {
                #align
            }

            fn read(mem: &wiggle::GuestMemory, location: wiggle::GuestPtr<Self>)
                -> Result<Self, wiggle::GuestError>
            {
                let tag = mem.read(location.cast::<#tag_ty>())?;
                match tag {
                    #(#read_variant)*
                    _ => Err(wiggle::GuestError::InvalidEnumValue(stringify!(#ident))),
                }
            }

            fn write(mem: &mut wiggle::GuestMemory, location: wiggle::GuestPtr<Self>, val: Self)
                -> Result<(), wiggle::GuestError>
            {
                match val {
                    #(#write_variant)*
                }
                Ok(())
            }
        }
    }
}

impl super::WiggleType for witx::Variant {
    /// Always reports `false` for variants. Note that `define_variant` only
    /// emits a `Display` impl when its `derive_std_error` flag is set.
    fn impls_display(&self) -> bool {
        false
    }
}