diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -11,96 +11,114 @@ // operations and intrinsics. However, we systematically prefix them with // "intr." to avoid potential name clashes. -class LLVM_UnaryIntrinsicOp traits = []> : +class LLVM_UnaryIntrOpBase traits = []> : LLVM_OneResultIntrOp { - let arguments = (ins LLVM_Type:$in); + let arguments = (ins LLVM_ScalarOrVectorOf:$in); } -class LLVM_BinarySameArgsIntrinsicOp traits = []> : +class LLVM_UnaryIntrOpI traits = []> : + LLVM_UnaryIntrOpBase; + +class LLVM_UnaryIntrOpF traits = []> : + LLVM_UnaryIntrOpBase; + +class LLVM_BinarySameArgsIntrOpBase traits = []> : LLVM_OneResultIntrOp { - let arguments = (ins LLVM_Type:$a, LLVM_Type:$b); + let arguments = (ins LLVM_ScalarOrVectorOf:$a, + LLVM_ScalarOrVectorOf:$b); } -class LLVM_BinaryIntrinsicOp traits = []> : - LLVM_OneResultIntrOp { - let arguments = (ins LLVM_Type:$a, LLVM_Type:$b); -} +class LLVM_BinarySameArgsIntrOpI traits = []> : + LLVM_BinarySameArgsIntrOpBase; -class LLVM_TernarySameArgsIntrinsicOp traits = []> : +class LLVM_BinarySameArgsIntrOpF traits = []> : + LLVM_BinarySameArgsIntrOpBase; + +class LLVM_TernarySameArgsIntrOpF traits = []> : LLVM_OneResultIntrOp { - let arguments = (ins LLVM_Type:$a, LLVM_Type:$b, LLVM_Type:$c); + let arguments = (ins LLVM_ScalarOrVectorOf:$a, + LLVM_ScalarOrVectorOf:$b, + LLVM_ScalarOrVectorOf:$c); } -class LLVM_CountZerosIntrinsicOp traits = []> : +class LLVM_CountZerosIntrOp traits = []> : LLVM_OneResultIntrOp { - let arguments = (ins LLVM_Type:$in, I1:$zero_undefined); + let arguments = (ins LLVM_ScalarOrVectorOf:$in, + I1:$zero_undefined); } def LLVM_AbsOp : LLVM_OneResultIntrOp<"abs", [], [0], [Pure]> { - let arguments = (ins LLVM_Type:$in, I1:$is_int_min_poison); + let arguments = (ins LLVM_ScalarOrVectorOf:$in, + I1:$is_int_min_poison); } def LLVM_IsFPClass : LLVM_OneResultIntrOp<"is.fpclass", [], [0], [Pure]> { - let arguments = (ins LLVM_Type:$in, I32:$bit); -} - -def LLVM_CopySignOp : LLVM_BinarySameArgsIntrinsicOp<"copysign">; -def LLVM_CosOp : LLVM_UnaryIntrinsicOp<"cos">; -def LLVM_ExpOp : LLVM_UnaryIntrinsicOp<"exp">; -def LLVM_Exp2Op : LLVM_UnaryIntrinsicOp<"exp2">; -def LLVM_FAbsOp : LLVM_UnaryIntrinsicOp<"fabs">; -def LLVM_FCeilOp : LLVM_UnaryIntrinsicOp<"ceil">; -def LLVM_FFloorOp : LLVM_UnaryIntrinsicOp<"floor">; -def LLVM_FMAOp : LLVM_TernarySameArgsIntrinsicOp<"fma">; -def LLVM_FMulAddOp : LLVM_TernarySameArgsIntrinsicOp<"fmuladd">; -def LLVM_Log10Op : LLVM_UnaryIntrinsicOp<"log10">; -def LLVM_Log2Op : LLVM_UnaryIntrinsicOp<"log2">; -def LLVM_LogOp : LLVM_UnaryIntrinsicOp<"log">; + let arguments = (ins LLVM_ScalarOrVectorOf:$in, I32:$bit); +} + +def LLVM_CopySignOp : LLVM_BinarySameArgsIntrOpF<"copysign">; +def LLVM_CosOp : LLVM_UnaryIntrOpF<"cos">; +def LLVM_ExpOp : LLVM_UnaryIntrOpF<"exp">; +def LLVM_Exp2Op : LLVM_UnaryIntrOpF<"exp2">; +def LLVM_FAbsOp : LLVM_UnaryIntrOpF<"fabs">; +def LLVM_FCeilOp : LLVM_UnaryIntrOpF<"ceil">; +def LLVM_FFloorOp : LLVM_UnaryIntrOpF<"floor">; +def LLVM_FMAOp : LLVM_TernarySameArgsIntrOpF<"fma">; +def LLVM_FMulAddOp : LLVM_TernarySameArgsIntrOpF<"fmuladd">; +def LLVM_Log10Op : LLVM_UnaryIntrOpF<"log10">; +def LLVM_Log2Op : LLVM_UnaryIntrOpF<"log2">; +def LLVM_LogOp : LLVM_UnaryIntrOpF<"log">; def LLVM_Prefetch : LLVM_ZeroResultIntrOp<"prefetch", [0]> { - let arguments = (ins LLVM_Type:$addr, LLVM_Type:$rw, LLVM_Type:$hint, - LLVM_Type:$cache); -} -def LLVM_SinOp : LLVM_UnaryIntrinsicOp<"sin">; -def LLVM_RoundEvenOp : LLVM_UnaryIntrinsicOp<"roundeven">; -def LLVM_RoundOp : LLVM_UnaryIntrinsicOp<"round">; -def LLVM_FTruncOp : LLVM_UnaryIntrinsicOp<"trunc">; -def LLVM_SqrtOp : LLVM_UnaryIntrinsicOp<"sqrt">; -def LLVM_PowOp : LLVM_BinarySameArgsIntrinsicOp<"pow">; -def LLVM_PowIOp : LLVM_BinaryIntrinsicOp<"powi">; -def LLVM_BitReverseOp : LLVM_UnaryIntrinsicOp<"bitreverse">; -def LLVM_CountLeadingZerosOp : LLVM_CountZerosIntrinsicOp<"ctlz">; -def LLVM_CountTrailingZerosOp : LLVM_CountZerosIntrinsicOp<"cttz">; -def LLVM_CtPopOp : LLVM_UnaryIntrinsicOp<"ctpop">; -def LLVM_MaxNumOp : LLVM_BinarySameArgsIntrinsicOp<"maxnum">; -def LLVM_MinNumOp : LLVM_BinarySameArgsIntrinsicOp<"minnum">; -def LLVM_MaximumOp : LLVM_BinarySameArgsIntrinsicOp<"maximum">; -def LLVM_MinimumOp : LLVM_BinarySameArgsIntrinsicOp<"minimum">; -def LLVM_SMaxOp : LLVM_BinarySameArgsIntrinsicOp<"smax">; -def LLVM_SMinOp : LLVM_BinarySameArgsIntrinsicOp<"smin">; -def LLVM_UMaxOp : LLVM_BinarySameArgsIntrinsicOp<"umax">; -def LLVM_UMinOp : LLVM_BinarySameArgsIntrinsicOp<"umin">; + let arguments = (ins LLVM_AnyPointer:$addr, I32:$rw, I32:$hint, I32:$cache); +} +def LLVM_SinOp : LLVM_UnaryIntrOpF<"sin">; +def LLVM_RoundEvenOp : LLVM_UnaryIntrOpF<"roundeven">; +def LLVM_RoundOp : LLVM_UnaryIntrOpF<"round">; +def LLVM_FTruncOp : LLVM_UnaryIntrOpF<"trunc">; +def LLVM_SqrtOp : LLVM_UnaryIntrOpF<"sqrt">; +def LLVM_PowOp : LLVM_BinarySameArgsIntrOpF<"pow">; +def LLVM_PowIOp : LLVM_OneResultIntrOp<"powi"> { + let arguments = (ins LLVM_ScalarOrVectorOf:$val, + AnySignlessInteger:$power); +} +def LLVM_BitReverseOp : LLVM_UnaryIntrOpI<"bitreverse">; +def LLVM_CountLeadingZerosOp : LLVM_CountZerosIntrOp<"ctlz">; +def LLVM_CountTrailingZerosOp : LLVM_CountZerosIntrOp<"cttz">; +def LLVM_CtPopOp : LLVM_UnaryIntrOpI<"ctpop">; +def LLVM_MaxNumOp : LLVM_BinarySameArgsIntrOpF<"maxnum">; +def LLVM_MinNumOp : LLVM_BinarySameArgsIntrOpF<"minnum">; +def LLVM_MaximumOp : LLVM_BinarySameArgsIntrOpF<"maximum">; +def LLVM_MinimumOp : LLVM_BinarySameArgsIntrOpF<"minimum">; +def LLVM_SMaxOp : LLVM_BinarySameArgsIntrOpI<"smax">; +def LLVM_SMinOp : LLVM_BinarySameArgsIntrOpI<"smin">; +def LLVM_UMaxOp : LLVM_BinarySameArgsIntrOpI<"umax">; +def LLVM_UMinOp : LLVM_BinarySameArgsIntrOpI<"umin">; def LLVM_MemcpyOp : LLVM_ZeroResultIntrOp<"memcpy", [0, 1, 2]> { - let arguments = (ins Arg:$dst, Arg:$src, LLVM_Type:$len, - LLVM_Type:$isVolatile); + let arguments = (ins Arg:$dst, + Arg:$src, + AnySignlessInteger:$len, I1:$isVolatile); } def LLVM_MemcpyInlineOp : LLVM_ZeroResultIntrOp<"memcpy.inline", [0, 1, 2]> { - let arguments = (ins Arg:$dst, Arg:$src, LLVM_Type:$len, - LLVM_Type:$isVolatile); + let arguments = (ins Arg:$dst, + Arg:$src, + AnySignlessInteger:$len, I1:$isVolatile); } def LLVM_MemmoveOp : LLVM_ZeroResultIntrOp<"memmove", [0, 1, 2]> { - let arguments = (ins Arg:$dst, Arg:$src, LLVM_Type:$len, - LLVM_Type:$isVolatile); + let arguments = (ins Arg:$dst, + Arg:$src, + AnySignlessInteger:$len, I1:$isVolatile); } def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2]> { - let arguments = (ins Arg:$dst, LLVM_Type:$val, LLVM_Type:$len, - LLVM_Type:$isVolatile); + let arguments = (ins Arg:$dst, + I8:$val, AnySignlessInteger:$len, I1:$isVolatile); } //===----------------------------------------------------------------------===// @@ -143,36 +161,20 @@ // Intrinsics with multiple returns. -def LLVM_SAddWithOverflowOp - : LLVM_IntrOp<"sadd.with.overflow", [0], [], [], 2> { - let arguments = (ins LLVM_Type, LLVM_Type); -} -def LLVM_UAddWithOverflowOp - : LLVM_IntrOp<"uadd.with.overflow", [0], [], [], 2> { - let arguments = (ins LLVM_Type, LLVM_Type); -} -def LLVM_SSubWithOverflowOp - : LLVM_IntrOp<"ssub.with.overflow", [0], [], [], 2> { - let arguments = (ins LLVM_Type, LLVM_Type); -} -def LLVM_USubWithOverflowOp - : LLVM_IntrOp<"usub.with.overflow", [0], [], [], 2> { - let arguments = (ins LLVM_Type, LLVM_Type); -} -def LLVM_SMulWithOverflowOp - : LLVM_IntrOp<"smul.with.overflow", [0], [], [], 2> { - let arguments = (ins LLVM_Type, LLVM_Type); -} -def LLVM_UMulWithOverflowOp - : LLVM_IntrOp<"umul.with.overflow", [0], [], [], 2> { - let arguments = (ins LLVM_Type, LLVM_Type); -} +class LLVM_ArithWithOverflowOp + : LLVM_IntrOp, + Arguments<(ins LLVM_ScalarOrVectorOf, + LLVM_ScalarOrVectorOf)>; +def LLVM_SAddWithOverflowOp : LLVM_ArithWithOverflowOp<"sadd.with.overflow">; +def LLVM_UAddWithOverflowOp : LLVM_ArithWithOverflowOp<"uadd.with.overflow">; +def LLVM_SSubWithOverflowOp : LLVM_ArithWithOverflowOp<"ssub.with.overflow">; +def LLVM_USubWithOverflowOp : LLVM_ArithWithOverflowOp<"usub.with.overflow">; +def LLVM_SMulWithOverflowOp : LLVM_ArithWithOverflowOp<"smul.with.overflow">; +def LLVM_UMulWithOverflowOp : LLVM_ArithWithOverflowOp<"umul.with.overflow">; -def LLVM_AssumeOp : LLVM_ZeroResultIntrOp<"assume", []> { - let arguments = (ins LLVM_Type:$cond); -} - +def LLVM_AssumeOp + : LLVM_ZeroResultIntrOp<"assume", []>, Arguments<(ins I1:$cond)>; // // Coroutine intrinsics. @@ -318,20 +320,66 @@ // Vector Reductions. // -def LLVM_vector_reduce_add : LLVM_VectorReduction<"add">; -def LLVM_vector_reduce_and : LLVM_VectorReduction<"and">; -def LLVM_vector_reduce_mul : LLVM_VectorReduction<"mul">; -def LLVM_vector_reduce_fmax : LLVM_VectorReduction<"fmax">; -def LLVM_vector_reduce_fmin : LLVM_VectorReduction<"fmin">; -def LLVM_vector_reduce_or : LLVM_VectorReduction<"or">; -def LLVM_vector_reduce_smax : LLVM_VectorReduction<"smax">; -def LLVM_vector_reduce_smin : LLVM_VectorReduction<"smin">; -def LLVM_vector_reduce_umax : LLVM_VectorReduction<"umax">; -def LLVM_vector_reduce_umin : LLVM_VectorReduction<"umin">; -def LLVM_vector_reduce_xor : LLVM_VectorReduction<"xor">; - -def LLVM_vector_reduce_fadd : LLVM_VectorReductionAcc<"fadd">; -def LLVM_vector_reduce_fmul : LLVM_VectorReductionAcc<"fmul">; +// LLVM vector reduction over a single vector. +class LLVM_VecReductionBase + : LLVM_OneResultIntrOp<"vector.reduce." # mnem, [], [0], + [Pure, SameOperandsAndResultElementType]>, + Arguments<(ins LLVM_VectorOf)>; + +class LLVM_VecReductionF + : LLVM_VecReductionBase; + +class LLVM_VecReductionI + : LLVM_VecReductionBase; + +// LLVM vector reduction over a single vector, with an initial value, +// and with permission to reassociate the reduction operations. +class LLVM_VecReductionAccBase + : LLVM_OneResultIntrOp<"vector.reduce." # mnem, [], [0], + [Pure, SameOperandsAndResultElementType]>, + Arguments<(ins element:$start_value, LLVM_VectorOf:$input, + DefaultValuedAttr:$reassoc)> { + let llvmBuilder = [{ + llvm::Module *module = builder.GetInsertBlock()->getModule(); + llvm::Function *fn = llvm::Intrinsic::getDeclaration( + module, + llvm::Intrinsic::vector_reduce_}] # mnem # [{, + { }] # !interleave(ListIntSubst.lst, + ", ") # [{ + }); + auto operands = moduleTranslation.lookupValues(opInst.getOperands()); + llvm::FastMathFlags origFM = builder.getFastMathFlags(); + llvm::FastMathFlags tempFM = origFM; + tempFM.setAllowReassoc($reassoc); + builder.setFastMathFlags(tempFM); // set fastmath flag + $res = builder.CreateCall(fn, operands); + builder.setFastMathFlags(origFM); // restore fastmath flag + }]; + let mlirBuilder = [{ + bool allowReassoc = inst->getFastMathFlags().allowReassoc(); + $res = $_builder.create<$_qualCppClassName>($_location, + $_resultType, $start_value, $input, allowReassoc); + }]; +} + +class LLVM_VecReductionAccF + : LLVM_VecReductionAccBase; + +def LLVM_vector_reduce_add : LLVM_VecReductionI<"add">; +def LLVM_vector_reduce_and : LLVM_VecReductionI<"and">; +def LLVM_vector_reduce_mul : LLVM_VecReductionI<"mul">; +def LLVM_vector_reduce_or : LLVM_VecReductionI<"or">; +def LLVM_vector_reduce_smax : LLVM_VecReductionI<"smax">; +def LLVM_vector_reduce_smin : LLVM_VecReductionI<"smin">; +def LLVM_vector_reduce_umax : LLVM_VecReductionI<"umax">; +def LLVM_vector_reduce_umin : LLVM_VecReductionI<"umin">; +def LLVM_vector_reduce_xor : LLVM_VecReductionI<"xor">; + +def LLVM_vector_reduce_fmax : LLVM_VecReductionF<"fmax">; +def LLVM_vector_reduce_fmin : LLVM_VecReductionF<"fmin">; + +def LLVM_vector_reduce_fadd : LLVM_VecReductionAccF<"fadd">; +def LLVM_vector_reduce_fmul : LLVM_VecReductionAccF<"fmul">; // // LLVM Matrix operations. @@ -345,12 +393,12 @@ /// columns - Number of columns in matrix (must be a constant) /// stride - Space between columns def LLVM_MatrixColumnMajorLoadOp : LLVM_OneResultIntrOp<"matrix.column.major.load"> { - let arguments = (ins LLVM_Type:$data, LLVM_Type:$stride, I1Attr:$isVolatile, + let arguments = (ins LLVM_AnyPointer:$data, AnySignlessInteger:$stride, I1Attr:$isVolatile, I32Attr:$rows, I32Attr:$columns); let results = (outs LLVM_AnyVector:$res); let builders = [LLVM_OneResultOpBuilder]; let assemblyFormat = "$data `,` `<` `stride` `=` $stride `>` attr-dict" - "`:` type($res) `from` type($data) `stride` type($stride)"; + "`:` type($res) `from` qualified(type($data)) `stride` type($stride)"; string llvmBuilder = [{ llvm::MatrixBuilder mb(builder); @@ -379,12 +427,12 @@ /// columns - Number of columns in matrix (must be a constant) /// stride - Space between columns def LLVM_MatrixColumnMajorStoreOp : LLVM_ZeroResultIntrOp<"matrix.column.major.store"> { - let arguments = (ins LLVM_AnyVector:$matrix, LLVM_Type:$data, - LLVM_Type:$stride, I1Attr:$isVolatile, I32Attr:$rows, + let arguments = (ins LLVM_AnyVector:$matrix, LLVM_AnyPointer:$data, + AnySignlessInteger:$stride, I1Attr:$isVolatile, I32Attr:$rows, I32Attr:$columns); let builders = [LLVM_VoidResultTypeOpBuilder, LLVM_ZeroResultOpBuilder]; let assemblyFormat = "$matrix `,` $data `,` `<` `stride` `=` $stride `>` " - "attr-dict`:` type($matrix) `to` type($data) `stride` type($stride)"; + "attr-dict`:` type($matrix) `to` qualified(type($data)) `stride` type($stride)"; string llvmBuilder = [{ llvm::MatrixBuilder mb(builder); @@ -407,9 +455,9 @@ /// Create a llvm.matrix.multiply call, multiplying 2-D matrices LHS and RHS, as /// specified in the LLVM MatrixBuilder. def LLVM_MatrixMultiplyOp : LLVM_OneResultIntrOp<"matrix.multiply"> { - let arguments = (ins LLVM_Type:$lhs, LLVM_Type:$rhs, I32Attr:$lhs_rows, + let arguments = (ins LLVM_AnyVector:$lhs, LLVM_AnyVector:$rhs, I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns); - let results = (outs LLVM_Type:$res); + let results = (outs LLVM_AnyVector:$res); let builders = [LLVM_OneResultOpBuilder]; let assemblyFormat = "$lhs `,` $rhs attr-dict " "`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)"; @@ -430,8 +478,8 @@ /// Create a llvm.matrix.transpose call, transposing a `rows` x `columns` 2-D /// `matrix`, as specified in the LLVM MatrixBuilder. def LLVM_MatrixTransposeOp : LLVM_OneResultIntrOp<"matrix.transpose"> { - let arguments = (ins LLVM_Type:$matrix, I32Attr:$rows, I32Attr:$columns); - let results = (outs LLVM_Type:$res); + let arguments = (ins LLVM_AnyVector:$matrix, I32Attr:$rows, I32Attr:$columns); + let results = (outs LLVM_AnyVector:$res); let builders = [LLVM_OneResultOpBuilder]; let assemblyFormat = "$matrix attr-dict `:` type($matrix) `into` type($res)"; @@ -454,15 +502,15 @@ /// Create a llvm.get.active.lane.mask to set a mask up to a given position. def LLVM_GetActiveLaneMaskOp : LLVM_OneResultIntrOp<"get.active.lane.mask", [0], [0], [Pure]> { - let arguments = (ins LLVM_Type:$base, LLVM_Type:$n); + let arguments = (ins AnySignlessInteger:$base, AnySignlessInteger:$n); let assemblyFormat = "$base `,` $n attr-dict `:` " "type($base) `,` type($n) `to` type($res)"; } /// Create a call to Masked Load intrinsic. def LLVM_MaskedLoadOp : LLVM_OneResultIntrOp<"masked.load"> { - let arguments = (ins LLVM_Type:$data, LLVM_Type:$mask, - Variadic:$pass_thru, I32Attr:$alignment); + let arguments = (ins LLVM_AnyPointer:$data, LLVM_VectorOf:$mask, + Variadic:$pass_thru, I32Attr:$alignment); let results = (outs LLVM_AnyVector:$res); let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; @@ -482,11 +530,11 @@ /// Create a call to Masked Store intrinsic. def LLVM_MaskedStoreOp : LLVM_ZeroResultIntrOp<"masked.store"> { - let arguments = (ins LLVM_Type:$value, LLVM_Type:$data, LLVM_Type:$mask, - I32Attr:$alignment); + let arguments = (ins LLVM_AnyVector:$value, LLVM_AnyPointer:$data, + LLVM_VectorOf:$mask, I32Attr:$alignment); let builders = [LLVM_VoidResultTypeOpBuilder, LLVM_ZeroResultOpBuilder]; let assemblyFormat = "$value `,` $data `,` $mask attr-dict `:` " - "type($value) `,` type($mask) `into` type($data)"; + "type($value) `,` type($mask) `into` qualified(type($data))"; string llvmBuilder = [{ builder.CreateMaskedStore( @@ -501,9 +549,10 @@ /// Create a call to Masked Gather intrinsic. def LLVM_masked_gather : LLVM_OneResultIntrOp<"masked.gather"> { - let arguments = (ins LLVM_AnyVector:$ptrs, LLVM_Type:$mask, - Variadic:$pass_thru, I32Attr:$alignment); - let results = (outs LLVM_Type:$res); + let arguments = (ins LLVM_VectorOf:$ptrs, + LLVM_VectorOf:$mask, Variadic:$pass_thru, + I32Attr:$alignment); + let results = (outs LLVM_AnyVector:$res); let builders = [LLVM_OneResultOpBuilder]; let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; @@ -523,8 +572,8 @@ /// Create a call to Masked Scatter intrinsic. def LLVM_masked_scatter : LLVM_ZeroResultIntrOp<"masked.scatter"> { - let arguments = (ins LLVM_Type:$value, LLVM_Type:$ptrs, LLVM_Type:$mask, - I32Attr:$alignment); + let arguments = (ins LLVM_AnyVector:$value, LLVM_VectorOf:$ptrs, + LLVM_VectorOf:$mask, I32Attr:$alignment); let builders = [LLVM_VoidResultTypeOpBuilder, LLVM_ZeroResultOpBuilder]; let assemblyFormat = "$value `,` $ptrs `,` $mask attr-dict `:` " "type($value) `,` type($mask) `into` type($ptrs)"; @@ -542,13 +591,13 @@ /// Create a call to Masked Expand Load intrinsic. def LLVM_masked_expandload : LLVM_IntrOp<"masked.expandload", [0], [], [], 1> { - let arguments = (ins LLVM_Type, LLVM_Type, LLVM_Type); + let arguments = (ins LLVM_AnyPointer, LLVM_VectorOf, LLVM_AnyVector); } /// Create a call to Masked Compress Store intrinsic. def LLVM_masked_compressstore : LLVM_IntrOp<"masked.compressstore", [], [0], [], 0> { - let arguments = (ins LLVM_Type, LLVM_Type, LLVM_Type); + let arguments = (ins LLVM_AnyVector, LLVM_AnyPointer, LLVM_VectorOf); } /// Create a call to vscale intrinsic. @@ -558,7 +607,7 @@ def LLVM_StepVectorOp : LLVM_IntrOp<"experimental.stepvector", [0], [], [Pure], 1> { let arguments = (ins); - let results = (outs LLVM_Type:$res); + let results = (outs LLVM_VectorOf:$res); let assemblyFormat = "attr-dict `:` type($res)"; } @@ -640,7 +689,7 @@ Arguments<(ins LLVM_VectorOf:$lhs, LLVM_VectorOf:$rhs, LLVM_VectorOf:$mask, I32:$evl)>; -class LLVM_VPBinaryI : LLVM_VPBinaryBase; +class LLVM_VPBinaryI : LLVM_VPBinaryBase; class LLVM_VPBinaryF : LLVM_VPBinaryBase; @@ -664,7 +713,7 @@ Arguments<(ins element:$satrt_value, LLVM_VectorOf:$val, LLVM_VectorOf:$mask, I32:$evl)>; -class LLVM_VPReductionI : LLVM_VPReductionBase; +class LLVM_VPReductionI : LLVM_VPReductionBase; class LLVM_VPReductionF : LLVM_VPReductionBase; @@ -678,7 +727,7 @@ Arguments<(ins LLVM_VectorOf:$src, LLVM_VectorOf:$mask, I32:$evl)>; -class LLVM_VPCastI : LLVM_VPCastBase; +class LLVM_VPCastI : LLVM_VPCastBase; class LLVM_VPCastF : LLVM_VPCastBase; @@ -747,13 +796,13 @@ // Strided load/store def LLVM_VPStridedLoadOp : LLVM_OneResultIntrOp<"experimental.vp.strided.load", [0], [0, 1], []>, - Arguments<(ins LLVM_AnyPointer:$ptr, AnyInteger:$stride, + Arguments<(ins LLVM_AnyPointer:$ptr, AnySignlessInteger:$stride, LLVM_VectorOf:$mask, I32:$evl)>; def LLVM_VPStridedStoreOp : LLVM_ZeroResultIntrOp<"experimental.vp.strided.store",[0, 1, 2], []>, Arguments<(ins LLVM_AnyVector:$val, LLVM_AnyPointer:$ptr, - AnyInteger:$stride, LLVM_VectorOf:$mask, I32:$evl)>; + AnySignlessInteger:$stride, LLVM_VectorOf:$mask, I32:$evl)>; def LLVM_VPTruncOp : LLVM_VPCastI<"trunc">; def LLVM_VPZExtOp : LLVM_VPCastI<"zext">; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -417,42 +417,6 @@ list traits = []> : LLVM_IntrOp; -// LLVM vector reduction over a single vector. -class LLVM_VectorReduction - : LLVM_OneResultIntrOp<"vector.reduce." # mnem, - [], [0], [Pure]>, - Arguments<(ins LLVM_Type)>; - -// LLVM vector reduction over a single vector, with an initial value, -// and with permission to reassociate the reduction operations. -class LLVM_VectorReductionAcc - : LLVM_OneResultIntrOp<"vector.reduce." # mnem, - [], [0], [Pure]>, - Arguments<(ins LLVM_Type:$start_value, LLVM_Type:$input, - DefaultValuedAttr:$reassoc)> { - let llvmBuilder = [{ - llvm::Module *module = builder.GetInsertBlock()->getModule(); - llvm::Function *fn = llvm::Intrinsic::getDeclaration( - module, - llvm::Intrinsic::vector_reduce_}] # mnem # [{, - { }] # !interleave(ListIntSubst.lst, - ", ") # [{ - }); - auto operands = moduleTranslation.lookupValues(opInst.getOperands()); - llvm::FastMathFlags origFM = builder.getFastMathFlags(); - llvm::FastMathFlags tempFM = origFM; - tempFM.setAllowReassoc($reassoc); - builder.setFastMathFlags(tempFM); // set fastmath flag - $res = builder.CreateCall(fn, operands); - builder.setFastMathFlags(origFM); // restore fastmath flag - }]; - let mlirBuilder = [{ - bool allowReassoc = inst->getFastMathFlags().allowReassoc(); - $res = $_builder.create<$_qualCppClassName>($_location, - $_resultType, $start_value, $input, allowReassoc); - }]; -} - def LLVM_OneResultOpBuilder : OpBuilder<(ins "Type":$resultType, "ValueRange":$operands, CArg<"ArrayRef", "{}">:$attributes), diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir --- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir @@ -1,82 +1,82 @@ // RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s -// expected-error @+1 {{cannot be converted to LLVM IR}} +// expected-error @below{{cannot be converted to LLVM IR}} func.func @foo() { llvm.return } // ----- -// expected-error @+1 {{llvm.noalias attribute attached to LLVM non-pointer argument}} +// expected-error @below{{llvm.noalias attribute attached to LLVM non-pointer argument}} llvm.func @invalid_noalias(%arg0 : f32 {llvm.noalias}) -> f32 { llvm.return %arg0 : f32 } // ----- -// expected-error @+1 {{llvm.sret attribute attached to LLVM non-pointer argument}} +// expected-error @below{{llvm.sret attribute attached to LLVM non-pointer argument}} llvm.func @invalid_sret(%arg0 : f32 {llvm.sret = f32}) -> f32 { llvm.return %arg0 : f32 } // ----- -// expected-error @+1 {{llvm.sret attribute attached to LLVM pointer argument of a different type}} +// expected-error @below{{llvm.sret attribute attached to LLVM pointer argument of a different type}} llvm.func @invalid_sret(%arg0 : !llvm.ptr {llvm.sret = i32}) -> !llvm.ptr { llvm.return %arg0 : !llvm.ptr } // ----- -// expected-error @+1 {{llvm.nest attribute attached to LLVM non-pointer argument}} +// expected-error @below{{llvm.nest attribute attached to LLVM non-pointer argument}} llvm.func @invalid_nest(%arg0 : f32 {llvm.nest}) -> f32 { llvm.return %arg0 : f32 } // ----- -// expected-error @+1 {{llvm.byval attribute attached to LLVM non-pointer argument}} +// expected-error @below{{llvm.byval attribute attached to LLVM non-pointer argument}} llvm.func @invalid_byval(%arg0 : f32 {llvm.byval = f32}) -> f32 { llvm.return %arg0 : f32 } // ----- -// expected-error @+1 {{llvm.byval attribute attached to LLVM pointer argument of a different type}} +// expected-error @below{{llvm.byval attribute attached to LLVM pointer argument of a different type}} llvm.func @invalid_sret(%arg0 : !llvm.ptr {llvm.byval = i32}) -> !llvm.ptr { llvm.return %arg0 : !llvm.ptr } // ----- -// expected-error @+1 {{llvm.byref attribute attached to LLVM non-pointer argument}} +// expected-error @below{{llvm.byref attribute attached to LLVM non-pointer argument}} llvm.func @invalid_byval(%arg0 : f32 {llvm.byref = f32}) -> f32 { llvm.return %arg0 : f32 } // ----- -// expected-error @+1 {{llvm.byref attribute attached to LLVM pointer argument of a different type}} +// expected-error @below{{llvm.byref attribute attached to LLVM pointer argument of a different type}} llvm.func @invalid_sret(%arg0 : !llvm.ptr {llvm.byref = i32}) -> !llvm.ptr { llvm.return %arg0 : !llvm.ptr } // ----- -// expected-error @+1 {{llvm.inalloca attribute attached to LLVM non-pointer argument}} +// expected-error @below{{llvm.inalloca attribute attached to LLVM non-pointer argument}} llvm.func @invalid_byval(%arg0 : f32 {llvm.inalloca = f32}) -> f32 { llvm.return %arg0 : f32 } // ----- -// expected-error @+1 {{llvm.inalloca attribute attached to LLVM pointer argument of a different type}} +// expected-error @below{{llvm.inalloca attribute attached to LLVM pointer argument of a different type}} llvm.func @invalid_sret(%arg0 : !llvm.ptr {llvm.inalloca = i32}) -> !llvm.ptr { llvm.return %arg0 : !llvm.ptr } // ----- -// expected-error @+1 {{llvm.align attribute attached to LLVM non-pointer argument}} +// expected-error @below{{llvm.align attribute attached to LLVM non-pointer argument}} llvm.func @invalid_align(%arg0 : f32 {llvm.align = 4}) -> f32 { llvm.return %arg0 : f32 } @@ -84,7 +84,7 @@ // ----- llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> { - // expected-error @+1 {{expected struct type to be a complex number}} + // expected-error @below{{expected struct type to be a complex number}} %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> } @@ -92,7 +92,7 @@ // ----- llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> { - // expected-error @+1 {{expected struct type to be a complex number}} + // expected-error @below{{expected struct type to be a complex number}} %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> } @@ -100,34 +100,237 @@ // ----- llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> { - // expected-error @+1 {{FloatAttr does not match expected type of the constant}} + // expected-error @below{{FloatAttr does not match expected type of the constant}} %0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)> llvm.return %0 : !llvm.struct<(f64, f64)> } // ----- -// expected-error @+1 {{unsupported constant value}} +// expected-error @below{{unsupported constant value}} llvm.mlir.global internal constant @test([2.5, 7.4]) : !llvm.array<2 x f64> // ----- -// expected-error @+1 {{LLVM attribute 'noinline' does not expect a value}} +// expected-error @below{{LLVM attribute 'noinline' does not expect a value}} llvm.func @passthrough_unexpected_value() attributes {passthrough = [["noinline", "42"]]} // ----- -// expected-error @+1 {{LLVM attribute 'alignstack' expects a value}} +// expected-error @below{{LLVM attribute 'alignstack' expects a value}} llvm.func @passthrough_expected_value() attributes {passthrough = ["alignstack"]} // ----- -// expected-error @+1 {{expected 'passthrough' to contain string or array attributes}} +// expected-error @below{{expected 'passthrough' to contain string or array attributes}} llvm.func @passthrough_wrong_type() attributes {passthrough = [42]} // ----- -// expected-error @+1 {{expected arrays within 'passthrough' to contain two strings}} +// expected-error @below{{expected arrays within 'passthrough' to contain two strings}} llvm.func @passthrough_wrong_type() attributes { passthrough = [[ 42, 42 ]] } + +// ----- + +llvm.func @unary_float_intr_wrong_type(%arg0 : i32) -> i32 { + // expected-error @below{{op operand #0 must be floating point LLVM type or LLVM dialect-compatible vector of floating point LLVM type}} + %0 = "llvm.intr.exp"(%arg0) : (i32) -> i32 + llvm.return %0 : i32 +} + +// ----- + +llvm.func @binary_float_intr_wrong_type(%arg0 : f32, %arg1 : i32) -> i32 { + // expected-error @below{{op operand #1 must be floating point LLVM type or LLVM dialect-compatible vector of floating point LLVM type}} + %0 = "llvm.intr.pow"(%arg0, %arg1) : (f32, i32) -> i32 + llvm.return %0 : i32 +} + +// ----- + +llvm.func @unary_int_intr_wrong_type(%arg0 : f32) -> f32 { + // expected-error @below{{op operand #0 must be signless integer or LLVM dialect-compatible vector of signless integer}} + %0 = "llvm.intr.ctpop"(%arg0) : (f32) -> f32 + llvm.return %0 : f32 +} + +// ----- + +llvm.func @binary_int_intr_wrong_type(%arg0 : i32, %arg1 : f32) -> f32 { + // expected-error @below{{op operand #1 must be signless integer or LLVM dialect-compatible vector of signless integer}} + %0 = "llvm.intr.smax"(%arg0, %arg1) : (i32, f32) -> f32 + llvm.return %0 : f32 +} + +// ----- + +llvm.func @ternary_float_intr_wrong_type(%arg0 : f32, %arg1 : f32, %arg2 : i32) -> f32 { + // expected-error @below{{op operand #2 must be floating-point or LLVM dialect-compatible vector of floating-point}} + %0 = "llvm.intr.fma"(%arg0, %arg1, %arg2) : (f32, f32, i32) -> f32 + llvm.return %0 : f32 +} + +// ----- + +llvm.func @powi_intr_wrong_type(%arg0 : f32, %arg1 : f32) -> f32 { + // expected-error @below{{op operand #1 must be signless integer, but got 'f32'}} + %0 = "llvm.intr.powi"(%arg0, %arg1) : (f32, f32) -> f32 + llvm.return %0 : f32 +} + +// ----- + +llvm.func @ctlz_intr_wrong_type(%arg0 : i32, %arg1 : i32) -> i32 { + // expected-error @below{{op operand #1 must be 1-bit signless integer, but got 'i32'}} + %0 = "llvm.intr.ctlz"(%arg0, %arg1) : (i32, i32) -> i32 + llvm.return %0 : i32 +} + +// ----- + +llvm.func @memcpy_intr_wrong_type(%src : i64, %dst : i64, %len : i64, %volatile : i1) { + // expected-error @below{{op operand #0 must be LLVM pointer type, but got 'i64'}} + "llvm.intr.memcpy"(%src, %dst, %len, %volatile) : (i64, i64, i64, i1) -> () + llvm.return +} + +// ----- + +llvm.func @memcpy_inline_intr_wrong_type(%src : !llvm.ptr, %dst : !llvm.ptr, %len : i64, %volatile : i32) { + // expected-error @below{{op operand #3 must be 1-bit signless integer, but got 'i32'}} + "llvm.intr.memcpy.inline"(%src, %dst, %len, %volatile) : (!llvm.ptr, !llvm.ptr, i64, i32) -> () + llvm.return +} + +// ----- + +llvm.func @memmove_intr_wrong_type(%src : !llvm.ptr, %dst : i64, %len : i64, %volatile : i1) { + // expected-error @below{{op operand #1 must be LLVM pointer type, but got 'i64'}} + "llvm.intr.memmove"(%src, %dst, %len, %volatile) : (!llvm.ptr, i64, i64, i1) -> () + llvm.return +} + +// ----- + +llvm.func @memset_intr_wrong_type(%dst : !llvm.ptr, %val : i32, %len : i64, %volatile : i1) { + // expected-error @below{{op operand #1 must be 8-bit signless integer, but got 'i32'}} + "llvm.intr.memset"(%dst, %val, %len, %volatile) : (!llvm.ptr, i32, i64, i1) -> () + llvm.return +} + +// ----- + +llvm.func @sadd_overflow_intr_wrong_type(%arg0 : i32, %arg1 : f32) -> !llvm.struct<(i32, i1)> { + // expected-error @below{{op operand #1 must be signless integer or LLVM dialect-compatible vector of signless integer, but got 'f32'}} + %0 = "llvm.intr.sadd.with.overflow"(%arg0, %arg1) : (i32, f32) -> !llvm.struct<(i32, i1)> + llvm.return %0 : !llvm.struct<(i32, i1)> +} + +// ----- + +llvm.func @assume_intr_wrong_type(%cond : i16) { + // expected-error @below{{op operand #0 must be 1-bit signless integer, but got 'i16'}} + "llvm.intr.assume"(%cond) : (i16) -> () + llvm.return +} + +// ----- + +llvm.func @vec_reduce_add_intr_wrong_type(%arg0 : vector<4xi32>) -> f32 { + // expected-error @below{{op requires the same element type for all operands and results}} + %0 = "llvm.intr.vector.reduce.add"(%arg0) : (vector<4xi32>) -> f32 + llvm.return %0 : f32 +} + +// ----- + +llvm.func @vec_reduce_fmax_intr_wrong_type(%arg0 : vector<4xi32>) -> i32 { + // expected-error @below{{op operand #0 must be LLVM dialect-compatible vector of floating-point}} + %0 = "llvm.intr.vector.reduce.fmax"(%arg0) : (vector<4xi32>) -> i32 + llvm.return %0 : i32 +} + +// ----- + +llvm.func @matrix_load_intr_wrong_type(%ptr : !llvm.ptr, %stride : i32) -> f32 { + // expected-error @below{{op result #0 must be LLVM dialect-compatible vector type, but got 'f32'}} + %0 = llvm.intr.matrix.column.major.load %ptr, + { isVolatile = 0: i1, rows = 3: i32, columns = 16: i32} : f32 from !llvm.ptr stride i32 + llvm.return %0 : f32 +} + +// ----- + +llvm.func @matrix_store_intr_wrong_type(%matrix : vector<48xf32>, %ptr : i32, %stride : i64) { + // expected-error @below {{op operand #1 must be LLVM pointer type, but got 'i32'}} + llvm.intr.matrix.column.major.store %matrix, %ptr, + { isVolatile = 0: i1, rows = 3: i32, columns = 16: i32} : vector<48xf32> to i32 stride i64 + llvm.return +} + +// ----- + +llvm.func @matrix_multiply_intr_wrong_type(%arg0 : vector<64xf32>, %arg1 : f32) -> vector<12xf32> { + // expected-error @below{{op operand #1 must be LLVM dialect-compatible vector type, but got 'f32'}} + %0 = llvm.intr.matrix.multiply %arg0, %arg1 + { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32} : (vector<64xf32>, f32) -> vector<12xf32> + llvm.return %0 : vector<12xf32> +} + +// ----- + +llvm.func @matrix_transpose_intr_wrong_type(%matrix : f32) -> vector<48xf32> { + // expected-error @below{{op operand #0 must be LLVM dialect-compatible vector type, but got 'f32'}} + %0 = llvm.intr.matrix.transpose %matrix {rows = 3: i32, columns = 16: i32} : f32 into vector<48xf32> + llvm.return %0 : vector<48xf32> +} + +// ----- + +llvm.func @active_lane_intr_wrong_type(%base : i64, %n : vector<7xi64>) -> vector<7xi1> { + // expected-error @below{{invalid kind of type specified}} + %0 = llvm.intr.get.active.lane.mask %base, %n : i64, vector<7xi64> to vector<7xi1> + llvm.return %0 : vector<7xi1> +} + +// ----- + +llvm.func @masked_load_intr_wrong_type(%ptr : i64, %mask : vector<7xi1>) -> vector<7xf32> { + // expected-error @below{{op operand #0 must be LLVM pointer type, but got 'i64'}} + %0 = llvm.intr.masked.load %ptr, %mask { alignment = 1: i32} : (i64, vector<7xi1>) -> vector<7xf32> + llvm.return %0 : vector<7xf32> +} + +// ----- + +llvm.func @masked_store_intr_wrong_type(%vec : vector<7xf32>, %ptr : !llvm.ptr, %mask : vector<7xi32>) { + // expected-error @below{{op operand #2 must be LLVM dialect-compatible vector of 1-bit signless integer, but got 'vector<7xi32>}} + llvm.intr.masked.store %vec, %ptr, %mask { alignment = 1: i32} : vector<7xf32>, vector<7xi32> into !llvm.ptr + llvm.return +} + +// ----- + +llvm.func @masked_gather_intr_wrong_type(%ptrs : vector<7xf32>, %mask : vector<7xi1>) -> vector<7xf32> { + // expected-error @below{{op operand #0 must be LLVM dialect-compatible vector of LLVM pointer type, but got 'vector<7xf32>'}} + %0 = llvm.intr.masked.gather %ptrs, %mask { alignment = 1: i32} : (vector<7xf32>, vector<7xi1>) -> vector<7xf32> + llvm.return %0 : vector<7xf32> +} + +// ----- + +llvm.func @masked_scatter_intr_wrong_type(%vec : f32, %ptrs : !llvm.vec<7xptr>, %mask : vector<7xi1>) { + // expected-error @below{{op operand #0 must be LLVM dialect-compatible vector type, but got 'f32'}} + llvm.intr.masked.scatter %vec, %ptrs, %mask { alignment = 1: i32} : f32, vector<7xi1> into !llvm.vec<7xptr> + llvm.return +} + +// ----- + +llvm.func @stepvector_intr_wrong_type() -> vector<7xf32> { + // expected-error @below{{op result #0 must be LLVM dialect-compatible vector of signless integer, but got 'vector<7xf32>'}} + %0 = llvm.intr.experimental.stepvector : vector<7xf32> + llvm.return %0 : vector<7xf32> +}