# frozen_string_literal: true module SyntaxTree module YARV class Assembler class ObjectVisitor < Compiler::RubyVisitor def visit_dyna_symbol(node) if node.parts.empty? :"" else raise CompilationError end end def visit_string_literal(node) case node.parts.length when 0 "" when 1 raise CompilationError unless node.parts.first.is_a?(TStringContent) node.parts.first.value else raise CompilationError end end end CALLDATA_FLAGS = { "ARGS_SPLAT" => CallData::CALL_ARGS_SPLAT, "ARGS_BLOCKARG" => CallData::CALL_ARGS_BLOCKARG, "FCALL" => CallData::CALL_FCALL, "VCALL" => CallData::CALL_VCALL, "ARGS_SIMPLE" => CallData::CALL_ARGS_SIMPLE, "BLOCKISEQ" => CallData::CALL_BLOCKISEQ, "KWARG" => CallData::CALL_KWARG, "KW_SPLAT" => CallData::CALL_KW_SPLAT, "TAILCALL" => CallData::CALL_TAILCALL, "SUPER" => CallData::CALL_SUPER, "ZSUPER" => CallData::CALL_ZSUPER, "OPT_SEND" => CallData::CALL_OPT_SEND, "KW_SPLAT_MUT" => CallData::CALL_KW_SPLAT_MUT }.freeze DEFINED_TYPES = [ nil, "nil", "instance-variable", "local-variable", "global-variable", "class variable", "constant", "method", "yield", "super", "self", "true", "false", "assignment", "expression", "ref", "func", "constant-from" ].freeze attr_reader :filepath def initialize(filepath) @filepath = filepath end def assemble iseq = InstructionSequence.new(:top, "<main>", nil, Location.default) assemble_iseq(iseq, File.readlines(filepath, chomp: true)) iseq.compile! iseq end def self.assemble(filepath) new(filepath).assemble end private def assemble_iseq(iseq, lines) labels = Hash.new { |hash, name| hash[name] = iseq.label } line_index = 0 while line_index < lines.length line = lines[line_index] line_index += 1 case line.strip when "", /^;/ # skip over blank lines and comments next when /^(\w+):$/ # create labels iseq.push(labels[$1]) next when /^__END__/ # skip over the rest of the file when we hit __END__ return end insn, operands = line.split(" ", 2) case insn when "adjuststack" iseq.adjuststack(parse_number(operands)) when "anytostring" iseq.anytostring when "branchif" iseq.branchif(labels[operands]) when "branchnil" iseq.branchnil(labels[operands]) when "branchunless" iseq.branchunless(labels[operands]) when "checkkeyword" kwbits_index, keyword_index = operands.split(/,\s*/) iseq.checkkeyword( parse_number(kwbits_index), parse_number(keyword_index) ) when "checkmatch" iseq.checkmatch(parse_number(operands)) when "checktype" iseq.checktype(parse_number(operands)) when "concatarray" iseq.concatarray when "concatstrings" iseq.concatstrings(parse_number(operands)) when "defineclass" body = parse_nested(lines[line_index..]) line_index += body.length name_value, flags_value = operands.split(/,\s*/) name = parse_symbol(name_value) flags = parse_number(flags_value) class_iseq = iseq.class_child_iseq(name.to_s, Location.default) assemble_iseq(class_iseq, body) iseq.defineclass(name, class_iseq, flags) when "defined" type, object, message = operands.split(/,\s*/) iseq.defined( DEFINED_TYPES.index(type), parse_symbol(object), parse_string(message) ) when "definemethod" body = parse_nested(lines[line_index..]) line_index += body.length name = parse_symbol(operands) method_iseq = iseq.method_child_iseq(name.to_s, Location.default) assemble_iseq(method_iseq, body) iseq.definemethod(name, method_iseq) when "definesmethod" body = parse_nested(lines[line_index..]) line_index += body.length name = parse_symbol(operands) method_iseq = iseq.method_child_iseq(name.to_s, Location.default) assemble_iseq(method_iseq, body) iseq.definesmethod(name, method_iseq) when "dup" iseq.dup when "dupn" iseq.dupn(parse_number(operands)) when "duparray" iseq.duparray(parse_type(operands, Array)) when "duphash" iseq.duphash(parse_type(operands, Hash)) when "expandarray" number, flags = operands.split(/,\s*/) iseq.expandarray(parse_number(number), parse_number(flags)) when "getblockparam" lookup = find_local(iseq, operands) iseq.getblockparam(lookup.index, lookup.level) when "getblockparamproxy" lookup = find_local(iseq, operands) iseq.getblockparamproxy(lookup.index, lookup.level) when "getclassvariable" iseq.getclassvariable(parse_symbol(operands)) when "getconstant" iseq.getconstant(parse_symbol(operands)) when "getglobal" iseq.getglobal(parse_symbol(operands)) when "getinstancevariable" iseq.getinstancevariable(parse_symbol(operands)) when "getlocal" lookup = find_local(iseq, operands) iseq.getlocal(lookup.index, lookup.level) when "getspecial" key, type = operands.split(/,\s*/) iseq.getspecial(parse_number(key), parse_number(type)) when "intern" iseq.intern when "invokeblock" iseq.invokeblock( operands ? parse_calldata(operands) : YARV.calldata(nil, 0) ) when "invokesuper" calldata = if operands parse_calldata(operands) else YARV.calldata( nil, 0, CallData::CALL_FCALL | CallData::CALL_ARGS_SIMPLE | CallData::CALL_SUPER ) end block_iseq = if lines[line_index].start_with?(" ") body = parse_nested(lines[line_index..]) line_index += body.length block_iseq = iseq.block_child_iseq(Location.default) assemble_iseq(block_iseq, body) block_iseq end iseq.invokesuper(calldata, block_iseq) when "jump" iseq.jump(labels[operands]) when "leave" iseq.leave when "newarray" iseq.newarray(parse_number(operands)) when "newarraykwsplat" iseq.newarraykwsplat(parse_number(operands)) when "newhash" iseq.newhash(parse_number(operands)) when "newrange" iseq.newrange(parse_options(operands, [0, 1])) when "nop" iseq.nop when "objtostring" iseq.objtostring(YARV.calldata(:to_s)) when "once" block_iseq = if lines[line_index].start_with?(" ") body = parse_nested(lines[line_index..]) line_index += body.length block_iseq = iseq.block_child_iseq(Location.default) assemble_iseq(block_iseq, body) block_iseq end iseq.once(block_iseq, iseq.inline_storage) when "opt_and" iseq.send(YARV.calldata(:&, 1)) when "opt_aref" iseq.send(YARV.calldata(:[], 1)) when "opt_aref_with" iseq.opt_aref_with(parse_string(operands), YARV.calldata(:[], 1)) when "opt_aset" iseq.send(YARV.calldata(:[]=, 2)) when "opt_aset_with" iseq.opt_aset_with(parse_string(operands), YARV.calldata(:[]=, 2)) when "opt_case_dispatch" cdhash_value, else_label_value = operands.split(/\s*\},\s*/) cdhash_value.sub!(/\A\{/, "") pairs = cdhash_value .split(/\s*,\s*/) .map! { |pair| pair.split(/\s*=>\s*/) } cdhash = pairs.to_h { |value, nm| [parse(value), labels[nm]] } else_label = labels[else_label_value] iseq.opt_case_dispatch(cdhash, else_label) when "opt_div" iseq.send(YARV.calldata(:/, 1)) when "opt_empty_p" iseq.send(YARV.calldata(:empty?)) when "opt_eq" iseq.send(YARV.calldata(:==, 1)) when "opt_ge" iseq.send(YARV.calldata(:>=, 1)) when "opt_gt" iseq.send(YARV.calldata(:>, 1)) when "opt_getconstant_path" iseq.opt_getconstant_path(parse_type(operands, Array)) when "opt_le" iseq.send(YARV.calldata(:<=, 1)) when "opt_length" iseq.send(YARV.calldata(:length)) when "opt_lt" iseq.send(YARV.calldata(:<, 1)) when "opt_ltlt" iseq.send(YARV.calldata(:<<, 1)) when "opt_minus" iseq.send(YARV.calldata(:-, 1)) when "opt_mod" iseq.send(YARV.calldata(:%, 1)) when "opt_mult" iseq.send(YARV.calldata(:*, 1)) when "opt_neq" iseq.send(YARV.calldata(:!=, 1)) when "opt_newarray_max" iseq.newarray(parse_number(operands)) iseq.send(YARV.calldata(:max)) when "opt_newarray_min" iseq.newarray(parse_number(operands)) iseq.send(YARV.calldata(:min)) when "opt_nil_p" iseq.send(YARV.calldata(:nil?)) when "opt_not" iseq.send(YARV.calldata(:!)) when "opt_or" iseq.send(YARV.calldata(:|, 1)) when "opt_plus" iseq.send(YARV.calldata(:+, 1)) when "opt_regexpmatch2" iseq.send(YARV.calldata(:=~, 1)) when "opt_reverse" iseq.send(YARV.calldata(:reverse)) when "opt_send_without_block" iseq.send(parse_calldata(operands)) when "opt_size" iseq.send(YARV.calldata(:size)) when "opt_str_freeze" iseq.putstring(parse_string(operands)) iseq.send(YARV.calldata(:freeze)) when "opt_str_uminus" iseq.putstring(parse_string(operands)) iseq.send(YARV.calldata(:-@)) when "opt_succ" iseq.send(YARV.calldata(:succ)) when "pop" iseq.pop when "putnil" iseq.putnil when "putobject" iseq.putobject(parse(operands)) when "putself" iseq.putself when "putspecialobject" iseq.putspecialobject(parse_options(operands, [1, 2, 3])) when "putstring" iseq.putstring(parse_string(operands)) when "send" block_iseq = if lines[line_index].start_with?(" ") body = parse_nested(lines[line_index..]) line_index += body.length block_iseq = iseq.block_child_iseq(Location.default) assemble_iseq(block_iseq, body) block_iseq end iseq.send(parse_calldata(operands), block_iseq) when "setblockparam" lookup = find_local(iseq, operands) iseq.setblockparam(lookup.index, lookup.level) when "setconstant" iseq.setconstant(parse_symbol(operands)) when "setglobal" iseq.setglobal(parse_symbol(operands)) when "setlocal" lookup = find_local(iseq, operands) iseq.setlocal(lookup.index, lookup.level) when "setn" iseq.setn(parse_number(operands)) when "setclassvariable" iseq.setclassvariable(parse_symbol(operands)) when "setinstancevariable" iseq.setinstancevariable(parse_symbol(operands)) when "setspecial" iseq.setspecial(parse_number(operands)) when "splatarray" iseq.splatarray(parse_options(operands, [true, false])) when "swap" iseq.swap when "throw" iseq.throw(parse_number(operands)) when "topn" iseq.topn(parse_number(operands)) when "toregexp" options, length = operands.split(", ") iseq.toregexp(parse_number(options), parse_number(length)) when "ARG_REQ" iseq.argument_size += 1 iseq.local_table.plain(operands.to_sym) when "ARG_BLOCK" iseq.argument_options[:block_start] = iseq.argument_size iseq.local_table.block(operands.to_sym) iseq.argument_size += 1 else raise "Could not understand: #{line}" end end end def find_local(iseq, operands) name_string, level_string = operands.split(/,\s*/) name = name_string.to_sym level = level_string&.to_i || 0 iseq.local_table.plain(name) iseq.local_table.find(name, level) end def parse(value) program = SyntaxTree.parse(value) raise if program.statements.body.length != 1 program.statements.body.first.accept(ObjectVisitor.new) end def parse_options(value, options) parse(value).tap { raise unless options.include?(_1) } end def parse_type(value, type) parse(value).tap { raise unless _1.is_a?(type) } end def parse_number(value) parse_type(value, Integer) end def parse_string(value) parse_type(value, String) end def parse_symbol(value) parse_type(value, Symbol) end def parse_nested(lines) body = lines.take_while { |line| line.match?(/^($|;| )/) } body.map! { |line| line.delete_prefix!(" ") || +"" } end def parse_calldata(value) message, argc_value, flags_value = value.split flags = if flags_value flags_value.split("|").map(&CALLDATA_FLAGS).inject(:|) else CallData::CALL_ARGS_SIMPLE end YARV.calldata(message.to_sym, argc_value&.to_i || 0, flags) end end end end