From 7857ad21e1af85525e9ef09e0cbead01cdb1750c Mon Sep 17 00:00:00 2001 From: Thomas Lively Date: Wed, 13 May 2026 15:03:44 -0700 Subject: [PATCH] Handle complex types in assertion wrapper modules Continuation types cannot be passed to JS, so `assert_return` actions must be implemented with WebAssembly wrapper modules when converting tests to JS. There is precedent for this with vector and exnref types. The difference for continuations is that they are defined types, so creating the wrapper modules with the correct import types requires arbitrarily complex type sections. Add code to traverse and topologically sort the rec groups reachable from the type of an exported function or global, then map them back to surface syntax that can be emitted. --- interpreter/script/js.ml | 161 ++++++++++++++++++++++++++++++++++----- 1 file changed, 140 insertions(+), 21 deletions(-) diff --git a/interpreter/script/js.ml b/interpreter/script/js.ml index a125e9a0..ae7e3fbd 100644 --- a/interpreter/script/js.ml +++ b/interpreter/script/js.ml @@ -289,7 +289,6 @@ let lookup_export (env : env) x_opt name at = let subject_idx = 0l let hostref_idx = 1l let eq_ref_idx = 2l -let subject_type_idx = 3l let eq_of = function | I32T -> I32 I32Op.Eq @@ -331,13 +330,12 @@ let value v = ] | Ref _ -> assert false -let invoke ft vs at = - let dt = RecT [SubT (Final, [], DefFuncT ft)] in - [dt @@ at], FuncImport (subject_type_idx @@ at) @@ at, +let invoke dt vs at = + FuncImport (0l @@ at) @@ at, List.concat (List.map value vs) @ [Call (subject_idx @@ at) @@ at] -let get t at = - [], GlobalImport t @@ at, [GlobalGet (subject_idx @@ at) @@ at] +let get gt at = + GlobalImport gt @@ at, [GlobalGet (subject_idx @@ at) @@ at] let run ts at = [], [] @@ -470,7 +468,7 @@ let assert_return ress ts at = BrIf (0l @@ at) @@ at ] | RefResult (RefPat _) -> assert false - | RefResult (RefTypePat (ExnHT | ExternHT)) -> + | RefResult (RefTypePat (ExnHT | ExternHT | ContHT)) -> [ BrOnNull (0l @@ at) @@ at ] | RefResult (RefTypePat t) -> [ RefTest (NoNull, t) @@ at; @@ -503,21 +501,137 @@ let eqref = RefT (Null, EqHT) let func_rec_type ts1 ts2 at = RecT [SubT (Final, [], DefFuncT (FuncT (ts1, ts2)))] @@ at -let wrap item_name wrap_action wrap_assertion at = - let itypes, idesc, action = wrap_action at in + + +let collect_and_sort_groups root_type = + let direct_deps rt = + let deps = ref [] in + let rec visit_ht = function + | DefHT (DefT (rt', _)) -> + if rt' != rt && not (List.exists (fun r -> r == rt') !deps) then + deps := rt' :: !deps + | _ -> () + and visit_sub (SubT (_, hts, st)) = + List.iter visit_ht hts; visit_str st + and visit_str = function + | DefFuncT (FuncT (ins, outs)) -> + List.iter visit_val ins; List.iter visit_val outs + | DefContT (ContT ht) -> visit_ht ht + | DefStructT (StructT fields) -> + List.iter (fun (FieldT (_, st)) -> visit_storage st) fields + | DefArrayT (ArrayT (FieldT (_, st))) -> visit_storage st + and visit_storage = function + | ValStorageT t -> visit_val t + | _ -> () + and visit_val = function + | RefT (_, ht) -> visit_ht ht + | _ -> () + in + let RecT sts = rt in + List.iter visit_sub sts; + List.rev !deps + in + let visited = ref [] in + let sorted = ref [] in + let rec visit rt = + if not (List.exists (fun r -> r == rt) !visited) then begin + visited := rt :: !visited; + List.iter visit (direct_deps rt); + sorted := rt :: !sorted + end + in + let visit_root_ht = function + | DefHT (DefT (rt, _)) -> visit rt + | _ -> () + in + let visit_root_val = function + | RefT (_, ht) -> visit_root_ht ht + | _ -> () + in + visit_root_val root_type; + List.rev !sorted + +let wrap item_name root_type wrap_action wrap_assertion at = + let idesc, action = wrap_action at in let locals, assertion = wrap_assertion at in + let sorted_groups = collect_and_sort_groups root_type in + let base_map, total_custom_size = + List.fold_left (fun (map, idx) rt -> + let RecT sts = rt in + ((rt, idx) :: map, Int32.add idx (Int32.of_int (List.length sts))) + ) ([], 0l) sorted_groups + in + let get_base rt = + List.assq rt base_map + in + let remap_ht current_group_rt_opt = function + | DefHT (DefT (rt, i)) -> VarHT (StatX (Int32.add (get_base rt) i)) + | VarHT (RecX i) -> + (match current_group_rt_opt with + | Some rt -> VarHT (StatX (Int32.add (get_base rt) i)) + | None -> failwith "remap_ht: RecX outside group") + | ht -> ht + in + let remap_ref current_group_rt_opt (nul, ht) = + (nul, remap_ht current_group_rt_opt ht) + in + let remap_val current_group_rt_opt = function + | RefT rt -> RefT (remap_ref current_group_rt_opt rt) + | t -> t + in + let remap_storage current_group_rt_opt = function + | ValStorageT t -> ValStorageT (remap_val current_group_rt_opt t) + | st -> st + in + let remap_field current_group_rt_opt (FieldT (m, st)) = + FieldT (m, remap_storage current_group_rt_opt st) + in + let remap_func current_group_rt_opt (FuncT (ins, outs)) = + FuncT (List.map (remap_val current_group_rt_opt) ins, + List.map (remap_val current_group_rt_opt) outs) + in + let remap_str current_group_rt_opt = function + | DefFuncT ft -> DefFuncT (remap_func current_group_rt_opt ft) + | DefContT (ContT ht) -> DefContT (ContT (remap_ht current_group_rt_opt ht)) + | DefStructT (StructT fields) -> + DefStructT (StructT (List.map (remap_field current_group_rt_opt) fields)) + | DefArrayT (ArrayT f) -> + DefArrayT (ArrayT (remap_field current_group_rt_opt f)) + in + let remap_sub current_group_rt_opt (SubT (fin, hts, st)) = + SubT (fin, List.map (remap_ht current_group_rt_opt) hts, + remap_str current_group_rt_opt st) + in + let custom_types = + List.map (fun rt -> + let RecT sts = rt in + let sts' = List.map (remap_sub (Some rt)) sts in + {it = RecT sts'; at = Source.no_region} + ) sorted_groups + in + let run_type_idx = total_custom_size in + let hostref_type_idx = Int32.add total_custom_size 1l in + let eqref_type_idx = Int32.add total_custom_size 2l in let types = - func_rec_type [] [] at :: - func_rec_type [i32] [anyref] at :: - func_rec_type [eqref; eqref] [i32] at :: - itypes + custom_types @ + [ func_rec_type [] [] at; + func_rec_type [i32] [anyref] at; + func_rec_type [eqref; eqref] [i32] at; + ] + in + let idesc' = match idesc.it, root_type with + | FuncImport _, RefT (_, DefHT (DefT (rt, i))) -> + FuncImport (Int32.add (get_base rt) i @@ at) + | GlobalImport t, _ -> GlobalImport t + | _ -> idesc.it in + let idesc = {idesc with it = idesc'} in let imports = [ {module_name = Utf8.decode "module"; item_name; idesc} @@ at; {module_name = Utf8.decode "spectest"; item_name = Utf8.decode "hostref"; - idesc = FuncImport (1l @@ at) @@ at} @@ at; + idesc = FuncImport (hostref_type_idx @@ at) @@ at} @@ at; {module_name = Utf8.decode "spectest"; item_name = Utf8.decode "eq_ref"; - idesc = FuncImport (2l @@ at) @@ at} @@ at; + idesc = FuncImport (eqref_type_idx @@ at) @@ at} @@ at; ] in let item = @@ -532,7 +646,7 @@ let wrap item_name wrap_action wrap_assertion at = [ Block (ValBlockType None, action @ assertion @ [Return @@ at]) @@ at; Unreachable @@ at ] in - let funcs = [{ftype = 0l @@ at; locals; body} @@ at] in + let funcs = [{ftype = run_type_idx @@ at; locals; body} @@ at] in let m = {empty_module with types; funcs; imports; exports} @@ at in (try Valid.check_module m; (* sanity check *) @@ -553,7 +667,11 @@ let is_js_vec_type = function | _ -> false let is_js_ref_type = function - | (_, ExnHT) -> false + | (_, (ExnHT | NoExnHT | ContHT | NoContHT)) -> false + | (_, DefHT dt) -> + (match expand_def_type dt with + | DefContT _ -> false + | _ -> true) | _ -> true let is_js_val_type = function @@ -665,9 +783,9 @@ let rec of_definition def = try of_definition (snd (Parse.Module.parse_string ~offset:s.at s.it)) with Parse.Syntax _ | Custom.Syntax _ -> of_bytes "" -let of_wrapper env x_opt name wrap_action wrap_assertion at = +let of_wrapper env x_opt name root_type wrap_action wrap_assertion at = let x = of_inst_opt env x_opt in - let bs = wrap name wrap_action wrap_assertion at in + let bs = wrap name root_type wrap_action wrap_assertion at in "call(instance(module(" ^ of_bytes bs ^ "), " ^ "exports(" ^ x ^ ")), " ^ " \"run\", [])" @@ -678,11 +796,12 @@ let of_action env act = "[" ^ String.concat ", " (List.map of_value vs) ^ "])", (match lookup_export env x_opt name act.at with | ExternFuncT dt -> + let root_type = RefT (NoNull, DefHT dt) in let FuncT (_, out) as ft = as_func_str_type (expand_def_type dt) in if is_js_func_type ft then None else - Some (of_wrapper env x_opt name (invoke ft vs), out) + Some (of_wrapper env x_opt name root_type (invoke dt vs), out) | _ -> None ) | Get (x_opt, name) -> @@ -690,7 +809,7 @@ let of_action env act = (match lookup_export env x_opt name act.at with | ExternGlobalT gt when not (is_js_global_type gt) -> let GlobalT (_, t) = gt in - Some (of_wrapper env x_opt name (get gt), [t]) + Some (of_wrapper env x_opt name t (get gt), [t]) | _ -> None )