Index: lib/Target/WebAssembly/Disassembler/WebAssemblyDisassembler.cpp =================================================================== --- lib/Target/WebAssembly/Disassembler/WebAssemblyDisassembler.cpp +++ lib/Target/WebAssembly/Disassembler/WebAssemblyDisassembler.cpp @@ -93,6 +93,7 @@ const MCOperandInfo &Info = Desc.OpInfo[i]; switch (Info.OperandType) { case MCOI::OPERAND_IMMEDIATE: + case WebAssembly::OPERAND_TABLE: case WebAssembly::OPERAND_P2ALIGN: case WebAssembly::OPERAND_BASIC_BLOCK: { if (Pos + sizeof(uint64_t) > Bytes.size()) Index: lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h =================================================================== --- lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h +++ lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h @@ -49,7 +49,9 @@ /// 64-bit floating-point immediates. OPERAND_FP64IMM, /// p2align immediate for load and store address alignment. - OPERAND_P2ALIGN + OPERAND_P2ALIGN, + /// Table index immediate selecting which indirect function table a call_indirect uses. + OPERAND_TABLE }; /// WebAssembly-specific directive identifiers.
Index: lib/Target/WebAssembly/WebAssemblyFastISel.cpp =================================================================== --- lib/Target/WebAssembly/WebAssemblyFastISel.cpp +++ lib/Target/WebAssembly/WebAssemblyFastISel.cpp @@ -658,7 +658,7 @@ Call->getFunctionType()->isVarArg()) return false; - Function *Func = Call->getCalledFunction(); + const Function *Func = dyn_cast(Call->getCalledValue()->stripPointerCasts()); if (Func && Func->isIntrinsic()) return false; @@ -754,6 +754,20 @@ if (!IsVoid) MIB.addReg(ResultReg, RegState::Define); + // Add the table operand (from !wasm.index metadata), default to zero + if (!IsDirect) { + if (MDNode *MD = Call->getMetadata("wasm.index")) + MIB.addImm(cast(MD->getOperand(0)) + ->getValue() + ->getUniqueInteger() + .getSExtValue()); + else { + if (I->getModule()->getNamedMetadata("wasm.index")) + DEBUG(dbgs() << "No indirect index is assigned: " << I->getFunction()->getName() << ": " << *I << "\n"); + MIB.addImm(0); + } + } + if (IsDirect) MIB.addGlobalAddress(Func); else Index: lib/Target/WebAssembly/WebAssemblyISelLowering.cpp =================================================================== --- lib/Target/WebAssembly/WebAssemblyISelLowering.cpp +++ lib/Target/WebAssembly/WebAssemblyISelLowering.cpp @@ -401,6 +401,23 @@ // Compute the operands for the CALLn node. SmallVector Ops; Ops.push_back(Chain); + + // Add the indirect call table operand as an i64, default to zero + if (CLI.CS && !isa(CLI.CS->getCalledValue()->stripPointerCasts())) { + APInt TableOp; + const Instruction *I = CLI.CS->getInstruction(); + if (MDNode *MD = I->getMetadata("wasm.index")) { + TableOp = cast(MD->getOperand(0)) + ->getValue() + ->getUniqueInteger().sextOrTrunc(64); + } else { + if (I->getModule()->getNamedMetadata("wasm.index")) + DEBUG(dbgs() << "No indirect index is assigned: " << I->getFunction()->getName() << ": " << *I << "\n"); + TableOp = APInt::getNullValue(64); + } + Ops.push_back(DAG.getConstant(TableOp, DL, MVT::i64)); + } + Ops.push_back(Callee); // Add all fixed arguments.
Note that for non-varargs calls, NumFixedArgs Index: lib/Target/WebAssembly/WebAssemblyInstrCall.td =================================================================== --- lib/Target/WebAssembly/WebAssemblyInstrCall.td +++ lib/Target/WebAssembly/WebAssemblyInstrCall.td @@ -29,9 +29,10 @@ def CALL_#vt : I<(outs vt:$dst), (ins i32imm:$callee, variable_ops), [(set vt:$dst, (WebAssemblycall1 (i32 imm:$callee)))], !strconcat(prefix, "call\t$dst, $callee")>; - def CALL_INDIRECT_#vt : I<(outs vt:$dst), (ins I32:$callee, variable_ops), - [(set vt:$dst, (WebAssemblycall1 I32:$callee))], - !strconcat(prefix, "call_indirect\t$dst, $callee")>; + def CALL_INDIRECT_#vt : I<(outs vt:$dst), + (ins table_op:$table, I32:$callee, variable_ops), + [(set vt:$dst, (WebAssemblycall1 imm:$table, I32:$callee))], + !strconcat(prefix, "call_indirect.$table\t$dst, $callee")>; } multiclass SIMD_CALL { @@ -59,10 +60,11 @@ def CALL_VOID : I<(outs), (ins i32imm:$callee, variable_ops), [(WebAssemblycall0 (i32 imm:$callee))], - "call \t$callee">; - def CALL_INDIRECT_VOID : I<(outs), (ins I32:$callee, variable_ops), - [(WebAssemblycall0 I32:$callee)], - "call_indirect\t$callee">; + "call\t$callee">; + def CALL_INDIRECT_VOID : I<(outs), + (ins table_op:$table, I32:$callee, variable_ops), + [(WebAssemblycall0 imm:$table, I32:$callee)], + "call_indirect.$table\t$callee">; } // Uses = [SP32,SP64], isCall = 1 } // Defs = [ARGUMENTS] Index: lib/Target/WebAssembly/WebAssemblyInstrInfo.td =================================================================== --- lib/Target/WebAssembly/WebAssemblyInstrInfo.td +++ lib/Target/WebAssembly/WebAssemblyInstrInfo.td @@ -83,6 +83,9 @@ } } // OperandType = "OPERAND_P2ALIGN" +let OperandType = "OPERAND_TABLE" in +def table_op : Operand; + } // OperandNamespace = "WebAssembly" //===----------------------------------------------------------------------===// Index: lib/Target/WebAssembly/WebAssemblyLowerEmscriptenExceptions.cpp
=================================================================== --- lib/Target/WebAssembly/WebAssemblyLowerEmscriptenExceptions.cpp +++ lib/Target/WebAssembly/WebAssemblyLowerEmscriptenExceptions.cpp @@ -408,6 +408,12 @@ AttributeSet NewCallPAL = AttributeSet::get(C, AttributesVec); NewCall->setAttributes(NewCallPAL); + // Add all metadata + SmallVector, 4> MDs; + II->getAllMetadata(MDs); + for (auto pair : MDs) { + NewCall->setMetadata(pair.first, pair.second); + } II->replaceAllUsesWith(NewCall); ToErase.push_back(II); @@ -428,6 +434,12 @@ NewCall->setCallingConv(II->getCallingConv()); NewCall->setDebugLoc(II->getDebugLoc()); NewCall->setAttributes(II->getAttributes()); + // Add all metadata + SmallVector, 4> MDs; + II->getAllMetadata(MDs); + for (auto pair : MDs) { + NewCall->setMetadata(pair.first, pair.second); + } II->replaceAllUsesWith(NewCall); ToErase.push_back(II); Index: lib/Transforms/IPO/LowerTypeTests.cpp =================================================================== --- lib/Transforms/IPO/LowerTypeTests.cpp +++ lib/Transforms/IPO/LowerTypeTests.cpp @@ -16,6 +16,7 @@ #include "llvm/ADT/EquivalenceClasses.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/Triple.h" +#include "llvm/IR/CallSite.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" @@ -259,6 +260,7 @@ ArrayRef Functions); void buildBitSetsFromFunctionsX86(ArrayRef TypeIds, ArrayRef Functions); + void labelIndirectCalls(CallInst *TypeTest, MDNode *MD); void buildBitSetsFromFunctionsWASM(ArrayRef TypeIds, ArrayRef Functions); void buildBitSetsFromDisjointSet(ArrayRef TypeIds, @@ -810,39 +812,104 @@ ConstantArray::get(JumpTableType, JumpTableEntries)); } -/// Assign a dummy layout using an incrementing counter, tag each function -/// with its index represented as metadata, and lower each type test to an -/// integer range comparison. During generation of the indirect function call -/// table in the backend, it will assign the given indexes.
+static bool labelCallSite(CallSite CS, MDNode *MD) { + if (!isa(CS.getCalledValue()->stripPointerCasts())) { + CS.getInstruction()->setMetadata("wasm.index", MD); + return true; + } + return false; +} + +void LowerTypeTests::labelIndirectCalls(CallInst *TypeTest, MDNode *MD) { + // Stop after the grandparent node that has all casts/GEPs stripped off + Value *Parent = TypeTest->getOperand(0)->stripPointerCasts(); + // After linking and optimizations, the type test may be on a direct call + if (isa(Parent)) + return; + + CallSite CS; + Value *Current = TypeTest; + // Search all intermediate levels between the type test and the grandparent + do { + // Iterate up to the next higher level + if (isa(Current)) + Current = cast(Current)->getOperand(0); + else if (isa(Current)) + Current = cast(Current)->getPointerOperand(); + else if (Current == TypeTest) + Current = TypeTest->getOperand(0); + else { + DEBUG(dbgs() << "Unknown: " << *Current << "\n"); + break; + } + + // Search siblings to find the original call site + for (User *U : Current->users()) { + // Skip our own branch + if (U == TypeTest->getOperand(0)) + continue; + // Label the indirect call site + else if ((CS = CallSite(U))) { + labelCallSite(CS, MD); + // The call site may be farther down the branch (e.g. bitcast) + } else if (isa(U) || isa(U)) { + for (User *Child : U->users()) + if ((CS = CallSite(Child))) + if (!labelCallSite(CS, MD) && !isa(U)) + DEBUG(dbgs() << "Unknown User: " << *U << "\n"); + } + } + } while (Current != Parent); +} + +/// Assign each disjoint set to a different indirect function call table using +/// an incrementing counter. Then, eliminate explicit type tests. This is safe +/// because the type homogeneity of each disjoint set is checked at load-time. +/// Additionally, calls to out of bounds table entries will trap at runtime. /// Note: Dynamic linking is not supported, as the WebAssembly ABI has not yet /// been finalized.
 void LowerTypeTests::buildBitSetsFromFunctionsWASM( ArrayRef TypeIds, ArrayRef Functions) { assert(!Functions.empty()); - - // Build consecutive monotonic integer ranges for each call target set - DenseMap GlobalLayout; - + NamedMDNode *NamedMD = M->getOrInsertNamedMetadata("wasm.index"); + // Create metadata that contains the table index + MDNode *MD = MDNode::get(M->getContext(), + ArrayRef(ConstantAsMetadata::get( + ConstantInt::get(Int64Ty, IndirectIndex)))); + NamedMD->addOperand(MD); + + // Tag functions with their table index for (Function *F : Functions) { // Skip functions that are not address taken, to avoid bloating the table if (!F->hasAddressTaken()) continue; - - // Store metadata with the index for each function - MDNode *MD = MDNode::get(F->getContext(), - ArrayRef(ConstantAsMetadata::get( - ConstantInt::get(Int64Ty, IndirectIndex)))); F->setMetadata("wasm.index", MD); + } - // Assign the counter value - GlobalLayout[F] = IndirectIndex++; + // Eliminate the explicit type test and tag the call sites + for (Metadata *TypeId : TypeIds) { + DEBUG({ + if (auto MDS = dyn_cast(TypeId)) + dbgs() << MDS->getString() << ": "; + else + dbgs() << ": "; + dbgs() << IndirectIndex << "\n"; + }); + + for (CallInst *TypeTest : TypeTestCallSites[TypeId]) { + labelIndirectCalls(TypeTest, MD); + + // Replace each type test with true, and let simplifycfg remove the branch + ConstantInt *True = ConstantInt::getTrue(TypeTest->getContext()); + TypeTest->replaceAllUsesWith(True); + TypeTest->eraseFromParent(); + + ++NumTypeTestCallsLowered; + } } - // The indirect function table index space starts at zero, so pass a NULL - // pointer as the subtracted "jump table" offset.
- lowerTypeTestCalls(TypeIds, - ConstantPointerNull::get(cast(Int32PtrTy)), - GlobalLayout); + // Increment the counter to put each disjoint set in a different table + IndirectIndex += 1; } void LowerTypeTests::buildBitSetsFromDisjointSet( @@ -1050,6 +1117,7 @@ LTT->Int64Ty = Type::getInt64Ty(M.getContext()); LTT->IntPtrTy = DL.getIntPtrType(M.getContext(), 0); LTT->TypeTestCallSites.clear(); + // In WebAssembly, the default table is at index 0, so start creating at 1 LTT->IndirectIndex = 1; } Index: lib/Transforms/Utils/LowerInvoke.cpp =================================================================== --- lib/Transforms/Utils/LowerInvoke.cpp +++ lib/Transforms/Utils/LowerInvoke.cpp @@ -56,6 +56,12 @@ NewCall->setCallingConv(II->getCallingConv()); NewCall->setAttributes(II->getAttributes()); NewCall->setDebugLoc(II->getDebugLoc()); + // Add all metadata + SmallVector, 4> MDs; + II->getAllMetadata(MDs); + for (auto pair : MDs) { + NewCall->setMetadata(pair.first, pair.second); + } II->replaceAllUsesWith(NewCall); // Insert an unconditional branch to the normal destination.
Index: test/CodeGen/WebAssembly/call.ll =================================================================== --- test/CodeGen/WebAssembly/call.ll +++ test/CodeGen/WebAssembly/call.ll @@ -80,7 +80,7 @@ ; CHECK-LABEL: call_indirect_void: ; CHECK-NEXT: .param i32{{$}} -; CHECK-NEXT: {{^}} call_indirect $0{{$}} +; CHECK-NEXT: {{^}} call_indirect.0 $0{{$}} ; CHECK-NEXT: return{{$}} define void @call_indirect_void(void ()* %callee) { call void %callee() @@ -90,7 +90,7 @@ ; CHECK-LABEL: call_indirect_i32: ; CHECK-NEXT: .param i32{{$}} ; CHECK-NEXT: .result i32{{$}} -; CHECK-NEXT: {{^}} i32.call_indirect $push[[NUM:[0-9]+]]=, $0{{$}} +; CHECK-NEXT: {{^}} i32.call_indirect.0 $push[[NUM:[0-9]+]]=, $0{{$}} ; CHECK-NEXT: return $pop[[NUM]]{{$}} define i32 @call_indirect_i32(i32 ()* %callee) { %t = call i32 %callee() Index: test/CodeGen/WebAssembly/cfi.ll =================================================================== --- test/CodeGen/WebAssembly/cfi.ll +++ test/CodeGen/WebAssembly/cfi.ll @@ -5,6 +5,12 @@ target datalayout = "e-m:e-p:32:32-i64:64-n32:64-S128" target triple = "wasm32-unknown-unknown" +%struct._IO_FILE1 = type { i32 } +%struct._IO_FILE2 = type opaque + +@i = internal global %struct._IO_FILE1 { i32 0 } +@stdout = hidden local_unnamed_addr constant %struct._IO_FILE1* @i + @0 = private unnamed_addr constant [2 x void (...)*] [void (...)* bitcast (void ()* @f to void (...)*), void (...)* bitcast (void ()* @g to void (...)*)], align 16 ; CHECK-LABEL: h: @@ -32,9 +38,8 @@ declare void @llvm.trap() nounwind noreturn ; CHECK-LABEL: foo: -; CHECK: br_if -; CHECK: br_if -; CHECK: unreachable +; CHECK-NOT: br_if +; CHECK-NOT: unreachable define i1 @foo(i8* %p) { %x = call i1 @llvm.type.test(i8* %p, metadata !"typeid1") br i1 %x, label %contx, label %trap @@ -51,3 +56,23 @@ %z = add i1 %x, %y ret i1 %z } + +define i32 @a(%struct._IO_FILE2* %f) { + ret i32 0 +} + +define i32 @b(i32 %f) { + %1 = add i32 %f, 1 + ret i32 %1 +} + +; CHECK-LABEL: c: +; CHECK-NOT:
i32.call_indirect +; CHECK: i32.call{{[[:space:]]}} +; CHECK: i32.call{{[[:space:]]}} +define i32 @c() { + %1 = load %struct._IO_FILE1*, %struct._IO_FILE1** @stdout + %call0 = call i32 bitcast (i32 (%struct._IO_FILE2*)* @a to i32 (%struct._IO_FILE1*)*)(%struct._IO_FILE1* %1) + %call1 = call i32 @b(i32 %call0) + ret i32 %call1 +} Index: test/Transforms/LowerTypeTests/function-disjoint.ll =================================================================== --- test/Transforms/LowerTypeTests/function-disjoint.ll +++ test/Transforms/LowerTypeTests/function-disjoint.ll @@ -32,12 +32,14 @@ define i1 @foo(i8* %p) { ; X64: icmp eq i64 {{.*}}, ptrtoint ([1 x <{ i8, i32, i8, i8, i8 }>]* @[[JT0]] to i64) - ; WASM32: icmp eq i64 {{.*}}, 1 + ; WASM32-NOT: icmp %x = call i1 @llvm.type.test(i8* %p, metadata !"typeid1") ; X64: icmp eq i64 {{.*}}, ptrtoint ([1 x <{ i8, i32, i8, i8, i8 }>]* @[[JT1]] to i64) - ; WASM32: icmp eq i64 {{.*}}, 2 + ; WASM32-NOT: icmp %y = call i1 @llvm.type.test(i8* %p, metadata !"typeid2") %z = add i1 %x, %y + ; WASM32-NOT: br i1 {{.*}} label %cont, label %trap + ; WASM32: add i1 true, true ret i1 %z } Index: test/Transforms/LowerTypeTests/function-ext.ll =================================================================== --- test/Transforms/LowerTypeTests/function-ext.ll +++ test/Transforms/LowerTypeTests/function-ext.ll @@ -12,8 +12,9 @@ define i1 @bar(i8* %ptr) { ; X64: icmp eq i64 {{.*}}, ptrtoint ([1 x <{ i8, i32, i8, i8, i8 }>]* @[[JT]] to i64) - ; WASM32: sub i64 {{.*}}, 0 - ; WASM32: icmp ult i64 {{.*}}, 1 + ; WASM32-NOT: sub i64 {{.*}}, 0 + ; WASM32-NOT: icmp ult i64 {{.*}}, 1 + ; WASM32: ret i1 true %p = call i1 @llvm.type.test(i8* %ptr, metadata !"void") ret i1 %p } Index: test/Transforms/LowerTypeTests/function.ll =================================================================== --- test/Transforms/LowerTypeTests/function.ll +++ test/Transforms/LowerTypeTests/function.ll @@ -13,13 +13,13 @@ ; X64: @g = alias void (), bitcast (<{ i8, i32, i8, i8, i8 }>* getelementptr inbounds ([2 x <{ i8, i32, i8, i8, i8 }>], [2 x <{ i8, i32, i8, i8, i8 }>]* @[[JT]], i64 0, i64 1) to void ()*) ; X64: define private void @[[FNAME]]() -; WASM32: define void @f() !type !{{[0-9]+}} !wasm.index ![[I0:[0-9]+]] +; WASM32: define void @f() !type !{{[0-9]+}} !wasm.index ![[I:[0-9]+]] define void @f() !type !0 { ret void } ; X64: define private void @[[GNAME]]() -; WASM32: define void @g() !type !{{[0-9]+}} !wasm.index ![[I1:[0-9]+]] +; WASM32: define void @g() !type !{{[0-9]+}} !wasm.index ![[I]] define void @g() !type !0 { ret void } @@ -30,11 +30,12 @@ define i1 @foo(i8* %p) { ; X64: sub i64 {{.*}}, ptrtoint ([2 x <{ i8, i32, i8, i8, i8 }>]* @[[JT]] to i64) - ; WASM32: sub i64 {{.*}}, 1 - ; WASM32: icmp ult i64 {{.*}}, 2 + ; WASM32-NOT: sub + ; WASM32-NOT: icmp + ; WASM32-NOT: br i1 {{.*}} label %cont, label %trap + ; WASM32: ret i1 true %x = call i1 @llvm.type.test(i8* %p, metadata !"typeid1") ret i1 %x } -; WASM32: ![[I0]] = !{i64 1} -; WASM32: ![[I1]] = !{i64 2} +; WASM32: ![[I]] = !{i64 1}
inbounds ([2 x <{ i8, i32, i8, i8, i8 }>], [2 x <{ i8, i32, i8, i8, i8 }>]* @[[JT]], i64 0, i64 1) to void ()*) ; X64: define private void @[[FNAME]]() -; WASM32: define void @f() !type !{{[0-9]+}} !wasm.index ![[I0:[0-9]+]] +; WASM32: define void @f() !type !{{[0-9]+}} !wasm.index ![[I:[0-9]+]] define void @f() !type !0 { ret void } ; X64: define private void @[[GNAME]]() -; WASM32: define void @g() !type !{{[0-9]+}} !wasm.index ![[I1:[0-9]+]] +; WASM32: define void @g() !type !{{[0-9]+}} !wasm.index ![[I]] define void @g() !type !0 { ret void } @@ -30,11 +30,12 @@ define i1 @foo(i8* %p) { ; X64: sub i64 {{.*}}, ptrtoint ([2 x <{ i8, i32, i8, i8, i8 }>]* @[[JT]] to i64) - ; WASM32: sub i64 {{.*}}, 1 - ; WASM32: icmp ult i64 {{.*}}, 2 + ; WASM32-NOT: sub + ; WASM32-NOT: icmp + ; WASM32-NOT: br i1 {{.*}} label %cont, label %trap + ; WASM32: ret i1 true %x = call i1 @llvm.type.test(i8* %p, metadata !"typeid1") ret i1 %x } -; WASM32: ![[I0]] = !{i64 1} -; WASM32: ![[I1]] = !{i64 2} +; WASM32: ![[I]] = !{i64 1}