diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -226,36 +226,46 @@ [{convertType(opInst.getOperand($0).getType().cast())}]; string result = [{convertType(opInst.getResult($0).getType().cast())}]; + string structResult = + [{convertType(opInst.getResult(0).getType().cast() + .getBody()[$0])}]; } // Base class for LLVM intrinsics operation. It is similar to LLVM_Op, but -// provides the "llvmBuilder" field for constructing the intrinsic. The builder -// relies on the contents on "overloadedResults" and "overloadedOperands" lists -// that contain the positions of intrinsic results and operands that are -// overloadable in the LLVM sense, that is their types must be passed in during -// the construction of the intrinsic declaration to differentiate between -// differently-typed versions of the intrinsic. "opName" contains the name of -// the operation to be associated with the intrinsic and "enumName" contains the -// name of the intrinsic as appears in `llvm::Intrinsic` enum; one usually wants -// these to be related. +// provides the "llvmBuilder" field for constructing the intrinsic. +// The builder relies on the contents of "overloadedResults" and +// "overloadedOperands" lists that contain the positions of intrinsic results +// and operands that are overloadable in the LLVM sense, that is, their types +// must be passed in during the construction of the intrinsic declaration to +// differentiate between differently-typed versions of the intrinsic. +// If the intrinsic has multiple results, they are packed into a single +// struct-typed result. In this case, the types of overloaded results are +// taken from the elements of that LLVMStructType, not from the result itself.
+// "opName" contains the name of the operation to be associated with the +// intrinsic and "enumName" contains the name of the intrinsic as it appears in +// the `llvm::Intrinsic` enum; one usually wants these to be related. class LLVM_IntrOpBase overloadedResults, list overloadedOperands, - list traits, bit hasResult> + list traits, int numResults> : LLVM_OpBase, - Results { + Results { + string resultPattern = !if(!gt(numResults, 1), + LLVM_IntrPatterns.structResult, + LLVM_IntrPatterns.result); let llvmBuilder = [{ llvm::Module *module = builder.GetInsertBlock()->getModule(); llvm::Function *fn = llvm::Intrinsic::getDeclaration( module, llvm::Intrinsic::}] # enumName # [{, { }] # StrJoin.lst, + ListIntSubst.lst, ListIntSubst.lst)>.result # [{ }); auto operands = lookupValues(opInst.getOperands()); - }] # !if(hasResult, "$res = ", "") # [{builder.CreateCall(fn, operands); + }] # !if(!gt(numResults, 0), "$res = ", "") + # [{builder.CreateCall(fn, operands); }]; } @@ -263,9 +273,10 @@ // the intrinsic into the LLVM dialect and prefixes its name with "intr.". class LLVM_IntrOp overloadedResults, list overloadedOperands, list traits, - bit hasResult> + int numResults> : LLVM_IntrOpBase; + overloadedResults, overloadedOperands, traits, + numResults>; // Base class for LLVM intrinsic operations returning no results. Places the // intrinsic into the LLVM dialect and prefixes its name with "intr.". diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -890,6 +890,33 @@ LLVM_Type:$isVolatile); } +// Intrinsics with multiple results.
+ +def LLVM_SAddWithOverflowOp + : LLVM_IntrOp<"sadd.with.overflow", [0], [], [], 2> { + let arguments = (ins LLVM_Type, LLVM_Type); +} +def LLVM_UAddWithOverflowOp + : LLVM_IntrOp<"uadd.with.overflow", [0], [], [], 2> { + let arguments = (ins LLVM_Type, LLVM_Type); +} +def LLVM_SSubWithOverflowOp + : LLVM_IntrOp<"ssub.with.overflow", [0], [], [], 2> { + let arguments = (ins LLVM_Type, LLVM_Type); +} +def LLVM_USubWithOverflowOp + : LLVM_IntrOp<"usub.with.overflow", [0], [], [], 2> { + let arguments = (ins LLVM_Type, LLVM_Type); +} +def LLVM_SMulWithOverflowOp + : LLVM_IntrOp<"smul.with.overflow", [0], [], [], 2> { + let arguments = (ins LLVM_Type, LLVM_Type); +} +def LLVM_UMulWithOverflowOp + : LLVM_IntrOp<"umul.with.overflow", [0], [], [], 2> { + let arguments = (ins LLVM_Type, LLVM_Type); +} + // // Vector Reductions. // diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -40,9 +40,9 @@ class NVVM_IntrOp overloadedResults, list overloadedOperands, list traits, - bit hasResult> + int numResults> : LLVM_IntrOpBase; + overloadedResults, overloadedOperands, traits, numResults>; //===----------------------------------------------------------------------===// diff --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir --- a/mlir/test/Target/llvmir-intrinsics.mlir +++ b/mlir/test/Target/llvmir-intrinsics.mlir @@ -293,6 +293,59 @@ llvm.return } +// CHECK-LABEL: @sadd_with_overflow_test +llvm.func @sadd_with_overflow_test(%arg0: !llvm.i32, %arg1: !llvm.i32, %arg2: !llvm.vec<8 x i32>, %arg3: !llvm.vec<8 x i32>) { + // CHECK: call { i32, i1 } @llvm.sadd.with.overflow.i32 + "llvm.intr.sadd.with.overflow"(%arg0, %arg1) : (!llvm.i32, !llvm.i32) -> !llvm.struct<(i32, i1)> + // CHECK: call { <8 x i32>, <8 x i1> } @llvm.sadd.with.overflow.v8i32 + "llvm.intr.sadd.with.overflow"(%arg2, %arg3) :
(!llvm.vec<8 x i32>, !llvm.vec<8 x i32>) -> !llvm.struct<(vec<8 x i32>, vec<8 x i1>)> + llvm.return +} + +// CHECK-LABEL: @uadd_with_overflow_test +llvm.func @uadd_with_overflow_test(%arg0: !llvm.i32, %arg1: !llvm.i32, %arg2: !llvm.vec<8 x i32>, %arg3: !llvm.vec<8 x i32>) { + // CHECK: call { i32, i1 } @llvm.uadd.with.overflow.i32 + "llvm.intr.uadd.with.overflow"(%arg0, %arg1) : (!llvm.i32, !llvm.i32) -> !llvm.struct<(i32, i1)> + // CHECK: call { <8 x i32>, <8 x i1> } @llvm.uadd.with.overflow.v8i32 + "llvm.intr.uadd.with.overflow"(%arg2, %arg3) : (!llvm.vec<8 x i32>, !llvm.vec<8 x i32>) -> !llvm.struct<(vec<8 x i32>, vec<8 x i1>)> + llvm.return +} + +// CHECK-LABEL: @ssub_with_overflow_test +llvm.func @ssub_with_overflow_test(%arg0: !llvm.i32, %arg1: !llvm.i32, %arg2: !llvm.vec<8 x i32>, %arg3: !llvm.vec<8 x i32>) { + // CHECK: call { i32, i1 } @llvm.ssub.with.overflow.i32 + "llvm.intr.ssub.with.overflow"(%arg0, %arg1) : (!llvm.i32, !llvm.i32) -> !llvm.struct<(i32, i1)> + // CHECK: call { <8 x i32>, <8 x i1> } @llvm.ssub.with.overflow.v8i32 + "llvm.intr.ssub.with.overflow"(%arg2, %arg3) : (!llvm.vec<8 x i32>, !llvm.vec<8 x i32>) -> !llvm.struct<(vec<8 x i32>, vec<8 x i1>)> + llvm.return +} + +// CHECK-LABEL: @usub_with_overflow_test +llvm.func @usub_with_overflow_test(%arg0: !llvm.i32, %arg1: !llvm.i32, %arg2: !llvm.vec<8 x i32>, %arg3: !llvm.vec<8 x i32>) { + // CHECK: call { i32, i1 } @llvm.usub.with.overflow.i32 + "llvm.intr.usub.with.overflow"(%arg0, %arg1) : (!llvm.i32, !llvm.i32) -> !llvm.struct<(i32, i1)> + // CHECK: call { <8 x i32>, <8 x i1> } @llvm.usub.with.overflow.v8i32 + "llvm.intr.usub.with.overflow"(%arg2, %arg3) : (!llvm.vec<8 x i32>, !llvm.vec<8 x i32>) -> !llvm.struct<(vec<8 x i32>, vec<8 x i1>)> + llvm.return +} + +// CHECK-LABEL: @smul_with_overflow_test +llvm.func @smul_with_overflow_test(%arg0: !llvm.i32, %arg1: !llvm.i32, %arg2: !llvm.vec<8 x i32>, %arg3: !llvm.vec<8 x i32>) { + // CHECK: call { i32, i1 } @llvm.smul.with.overflow.i32 +
"llvm.intr.smul.with.overflow"(%arg0, %arg1) : (!llvm.i32, !llvm.i32) -> !llvm.struct<(i32, i1)> + // CHECK: call { <8 x i32>, <8 x i1> } @llvm.smul.with.overflow.v8i32 + "llvm.intr.smul.with.overflow"(%arg2, %arg3) : (!llvm.vec<8 x i32>, !llvm.vec<8 x i32>) -> !llvm.struct<(vec<8 x i32>, vec<8 x i1>)> + llvm.return +} + +// CHECK-LABEL: @umul_with_overflow_test +llvm.func @umul_with_overflow_test(%arg0: !llvm.i32, %arg1: !llvm.i32, %arg2: !llvm.vec<8 x i32>, %arg3: !llvm.vec<8 x i32>) { + // CHECK: call { i32, i1 } @llvm.umul.with.overflow.i32 + "llvm.intr.umul.with.overflow"(%arg0, %arg1) : (!llvm.i32, !llvm.i32) -> !llvm.struct<(i32, i1)> + // CHECK: call { <8 x i32>, <8 x i1> } @llvm.umul.with.overflow.v8i32 + "llvm.intr.umul.with.overflow"(%arg2, %arg3) : (!llvm.vec<8 x i32>, !llvm.vec<8 x i32>) -> !llvm.struct<(vec<8 x i32>, vec<8 x i1>)> + llvm.return +} // Check that intrinsics are declared with appropriate types. // CHECK-DAG: declare float @llvm.fma.f32(float, float, float) @@ -330,3 +383,15 @@ // CHECK-DAG: declare void @llvm.masked.compressstore.v7f32(<7 x float>, float*, <7 x i1>) // CHECK-DAG: declare void @llvm.memcpy.p0i8.p0i8.i32(i8* noalias nocapture writeonly, i8* noalias nocapture readonly, i32, i1 immarg) // CHECK-DAG: declare void @llvm.memcpy.inline.p0i8.p0i8.i64(i8* noalias nocapture writeonly, i8* noalias nocapture readonly, i64 immarg, i1 immarg) +// CHECK-DAG: declare { i32, i1 } @llvm.sadd.with.overflow.i32(i32, i32) +// CHECK-DAG: declare { <8 x i32>, <8 x i1> } @llvm.sadd.with.overflow.v8i32(<8 x i32>, <8 x i32>) +// CHECK-DAG: declare { i32, i1 } @llvm.uadd.with.overflow.i32(i32, i32) +// CHECK-DAG: declare { <8 x i32>, <8 x i1> } @llvm.uadd.with.overflow.v8i32(<8 x i32>, <8 x i32>) +// CHECK-DAG: declare { i32, i1 } @llvm.ssub.with.overflow.i32(i32, i32) +// CHECK-DAG: declare { <8 x i32>, <8 x i1> } @llvm.ssub.with.overflow.v8i32(<8 x i32>, <8 x i32>) +// CHECK-DAG: declare { i32, i1 } @llvm.usub.with.overflow.i32(i32, i32)
+// CHECK-DAG: declare { <8 x i32>, <8 x i1> } @llvm.usub.with.overflow.v8i32(<8 x i32>, <8 x i32>) +// CHECK-DAG: declare { i32, i1 } @llvm.smul.with.overflow.i32(i32, i32) +// CHECK-DAG: declare { <8 x i32>, <8 x i1> } @llvm.smul.with.overflow.v8i32(<8 x i32>, <8 x i32>) +// CHECK-DAG: declare { i32, i1 } @llvm.umul.with.overflow.i32(i32, i32) +// CHECK-DAG: declare { <8 x i32>, <8 x i1> } @llvm.umul.with.overflow.v8i32(<8 x i32>, <8 x i32>) diff --git a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp --- a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp @@ -210,7 +210,7 @@ printBracketedRange(intr.getOverloadableOperandsIdxs().set_bits(), os); os << ", "; printBracketedRange(traits, os); - os << ", " << (intr.getNumResults() == 0 ? 0 : 1) << ">, Arguments<(ins" + os << ", " << intr.getNumResults() << ">, Arguments<(ins" << (operands.empty() ? "" : " "); llvm::interleaveComma(operands, os); os << ")>;\n\n";