diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -400,60 +400,84 @@ } /// Create a call to Masked Load intrinsic. -def LLVM_MaskedLoadOp : LLVM_Op<"intr.masked.load"> { +def LLVM_MaskedLoadOp : LLVM_OneResultIntrOp<"masked.load"> { let arguments = (ins LLVM_Type:$data, LLVM_Type:$mask, Variadic:$pass_thru, I32Attr:$alignment); let results = (outs LLVM_AnyVector:$res); + let assemblyFormat = + "operands attr-dict `:` functional-type(operands, results)"; + string llvmBuilder = [{ $res = $pass_thru.empty() ? builder.CreateMaskedLoad( $_resultType, $data, llvm::Align($alignment), $mask) : builder.CreateMaskedLoad( $_resultType, $data, llvm::Align($alignment), $mask, $pass_thru[0]); }]; - let assemblyFormat = - "operands attr-dict `:` functional-type(operands, results)"; + string mlirBuilder = [{ + $res = $_builder.create($_location, + $_resultType, $data, $mask, $pass_thru, $_int_attr($alignment)); + }]; + list llvmArgIndices = [0, 2, 3, 1]; } /// Create a call to Masked Store intrinsic. 
-def LLVM_MaskedStoreOp : LLVM_Op<"intr.masked.store"> { +def LLVM_MaskedStoreOp : LLVM_ZeroResultIntrOp<"masked.store"> { let arguments = (ins LLVM_Type:$value, LLVM_Type:$data, LLVM_Type:$mask, I32Attr:$alignment); let builders = [LLVM_VoidResultTypeOpBuilder, LLVM_ZeroResultOpBuilder]; + let assemblyFormat = "$value `,` $data `,` $mask attr-dict `:` " + "type($value) `,` type($mask) `into` type($data)"; + string llvmBuilder = [{ builder.CreateMaskedStore( $value, $data, llvm::Align($alignment), $mask); }]; - let assemblyFormat = "$value `,` $data `,` $mask attr-dict `:` " - "type($value) `,` type($mask) `into` type($data)"; + string mlirBuilder = [{ + $_builder.create($_location, + $value, $data, $mask, $_int_attr($alignment)); + }]; + list llvmArgIndices = [0, 1, 3, 2]; } /// Create a call to Masked Gather intrinsic. -def LLVM_masked_gather : LLVM_Op<"intr.masked.gather"> { +def LLVM_masked_gather : LLVM_OneResultIntrOp<"masked.gather"> { let arguments = (ins LLVM_AnyVector:$ptrs, LLVM_Type:$mask, Variadic:$pass_thru, I32Attr:$alignment); let results = (outs LLVM_Type:$res); let builders = [LLVM_OneResultOpBuilder]; + let assemblyFormat = + "operands attr-dict `:` functional-type(operands, results)"; + string llvmBuilder = [{ $res = $pass_thru.empty() ? builder.CreateMaskedGather( $_resultType, $ptrs, llvm::Align($alignment), $mask) : builder.CreateMaskedGather( $_resultType, $ptrs, llvm::Align($alignment), $mask, $pass_thru[0]); }]; - let assemblyFormat = - "operands attr-dict `:` functional-type(operands, results)"; + string mlirBuilder = [{ + $res = $_builder.create($_location, + $_resultType, $ptrs, $mask, $pass_thru, $_int_attr($alignment)); + }]; + list llvmArgIndices = [0, 2, 3, 1]; } /// Create a call to Masked Scatter intrinsic. 
-def LLVM_masked_scatter : LLVM_Op<"intr.masked.scatter"> { +def LLVM_masked_scatter : LLVM_ZeroResultIntrOp<"masked.scatter"> { let arguments = (ins LLVM_Type:$value, LLVM_Type:$ptrs, LLVM_Type:$mask, I32Attr:$alignment); let builders = [LLVM_VoidResultTypeOpBuilder, LLVM_ZeroResultOpBuilder]; + let assemblyFormat = "$value `,` $ptrs `,` $mask attr-dict `:` " + "type($value) `,` type($mask) `into` type($ptrs)"; + string llvmBuilder = [{ builder.CreateMaskedScatter( $value, $ptrs, llvm::Align($alignment), $mask); }]; - let assemblyFormat = "$value `,` $ptrs `,` $mask attr-dict `:` " - "type($value) `,` type($mask) `into` type($ptrs)"; + string mlirBuilder = [{ + $_builder.create($_location, + $value, $ptrs, $mask, $_int_attr($alignment)); + }]; + list llvmArgIndices = [0, 1, 3, 2]; } /// Create a call to Masked Expand Load intrinsic. diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -228,7 +228,6 @@ // - $_builder - substituted with the MLIR builder; // - $_qualCppClassName - substitiuted with the MLIR operation class name. // Additionally, `$$` can be used to produce the dollar character. - // FIXME: The $name variable resolution does not support variadic arguments. string mlirBuilder = ""; // An array that specifies a mapping from MLIR argument indices to LLVM IR diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -243,6 +243,13 @@ return position; } +/// Drops the first `n` elements of the `values` array. 
+static SmallVector dropFront(ArrayRef values, + int64_t n) { + SmallVector result(values.drop_front(n)); + return result; +} + DataLayoutSpecInterface mlir::translateDataLayout(const llvm::DataLayout &dataLayout, MLIRContext *context) { diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll --- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll +++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll @@ -278,19 +278,35 @@ ret <7 x i1> %3 } -; TODO: masked load store intrinsics should be handled specially. -define void @masked_load_store_intrinsics(<7 x float>* %0, <7 x i1> %1) { - %3 = call <7 x float> @llvm.masked.load.v7f32.p0v7f32(<7 x float>* %0, i32 1, <7 x i1> %1, <7 x float> undef) - %4 = call <7 x float> @llvm.masked.load.v7f32.p0v7f32(<7 x float>* %0, i32 1, <7 x i1> %1, <7 x float> %3) - call void @llvm.masked.store.v7f32.p0v7f32(<7 x float> %4, <7 x float>* %0, i32 1, <7 x i1> %1) - ret void -} - -; TODO: masked gather scatter intrinsics should be handled specially. 
-define void @masked_gather_scatter_intrinsics(<7 x float*> %0, <7 x i1> %1) { - %3 = call <7 x float> @llvm.masked.gather.v7f32.v7p0f32(<7 x float*> %0, i32 1, <7 x i1> %1, <7 x float> undef) - %4 = call <7 x float> @llvm.masked.gather.v7f32.v7p0f32(<7 x float*> %0, i32 1, <7 x i1> %1, <7 x float> %3) - call void @llvm.masked.scatter.v7f32.v7p0f32(<7 x float> %4, <7 x float*> %0, i32 1, <7 x i1> %1) +; CHECK-LABEL: @masked_load_store_intrinsics +; CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]] +; CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]] +define void @masked_load_store_intrinsics(<7 x float>* %vec, <7 x i1> %mask) { + ; CHECK: %[[UNDEF:.+]] = llvm.mlir.undef + ; CHECK: %[[VAL1:.+]] = llvm.intr.masked.load %[[VEC]], %[[MASK]], %[[UNDEF]] {alignment = 1 : i32} + ; CHECK-SAME: (!llvm.ptr>, vector<7xi1>, vector<7xf32>) -> vector<7xf32> + %1 = call <7 x float> @llvm.masked.load.v7f32.p0v7f32(<7 x float>* %vec, i32 1, <7 x i1> %mask, <7 x float> undef) + ; CHECK: %[[VAL2:.+]] = llvm.intr.masked.load %[[VEC]], %[[MASK]], %[[VAL1]] {alignment = 4 : i32} + %2 = call <7 x float> @llvm.masked.load.v7f32.p0v7f32(<7 x float>* %vec, i32 4, <7 x i1> %mask, <7 x float> %1) + ; CHECK: llvm.intr.masked.store %[[VAL2]], %[[VEC]], %[[MASK]] {alignment = 8 : i32} + ; CHECK-SAME: vector<7xf32>, vector<7xi1> into !llvm.ptr> + call void @llvm.masked.store.v7f32.p0v7f32(<7 x float> %2, <7 x float>* %vec, i32 8, <7 x i1> %mask) + ret void +} + +; CHECK-LABEL: @masked_gather_scatter_intrinsics +; CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]] +; CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]] +define void @masked_gather_scatter_intrinsics(<7 x float*> %vec, <7 x i1> %mask) { + ; CHECK: %[[UNDEF:.+]] = llvm.mlir.undef + ; CHECK: %[[VAL1:.+]] = llvm.intr.masked.gather %[[VEC]], %[[MASK]], %[[UNDEF]] {alignment = 1 : i32} + ; CHECK-SAME: (!llvm.vec<7 x ptr>, vector<7xi1>, vector<7xf32>) -> vector<7xf32> + %1 = call <7 x float> @llvm.masked.gather.v7f32.v7p0f32(<7 x float*> %vec, i32 1, <7 x i1> %mask, <7 x float> undef) + ; CHECK: 
%[[VAL2:.+]] = llvm.intr.masked.gather %[[VEC]], %[[MASK]], %[[VAL1]] {alignment = 4 : i32} + %2 = call <7 x float> @llvm.masked.gather.v7f32.v7p0f32(<7 x float*> %vec, i32 4, <7 x i1> %mask, <7 x float> %1) + ; CHECK: llvm.intr.masked.scatter %[[VAL2]], %[[VEC]], %[[MASK]] {alignment = 8 : i32} + ; CHECK-SAME: vector<7xf32>, vector<7xi1> into !llvm.vec<7 x ptr> + call void @llvm.masked.scatter.v7f32.v7p0f32(<7 x float> %2, <7 x float*> %vec, i32 8, <7 x i1> %mask) ret void } diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp --- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp @@ -66,14 +66,14 @@ return {startPos, endPos - startPos}; } -// Check if `name` is the name of the variadic operand of `op`. The variadic -// operand can only appear at the last position in the list of operands. +// Check if `name` is a variadic operand of `op`. Search all operands since the +// MLIR and LLVM IR operand order may differ and only for the latter the +// variadic operand is guaranteed to be at the end of the operands list. static bool isVariadicOperandName(const tblgen::Operator &op, StringRef name) { - unsigned numOperands = op.getNumOperands(); - if (numOperands == 0) - return false; - const auto &operand = op.getOperand(numOperands - 1); - return operand.isVariableLength() && operand.name == name; + for (int i = 0, e = op.getNumOperands(); i < e; ++i) + if (op.getOperand(i).name == name) + return op.getOperand(i).isVariableLength(); + return false; } // Check if `result` is a known name of a result of `op`. @@ -227,12 +227,15 @@ if (succeeded(argIndex)) { // Access the LLVM IR operand that maps to the given argument index using // the provided argument indices mapping. - // FIXME: support trailing variadic arguments. 
- int64_t operandIdx = llvmArgIndices[*argIndex]; - if (operandIdx < 0) + int64_t idx = llvmArgIndices[*argIndex]; + if (idx < 0) PrintFatalError(record.getLoc(), "expected non-negative operand index"); - assert(!isVariadicOperandName(op, name) && "unexpected variadic operand"); - bs << formatv("processValue(llvmOperands[{0}])", operandIdx); + bool isVariadicOperand = isVariadicOperandName(op, name); + auto result = + isVariadicOperand + ? formatv("processValues(dropFront(llvmOperands, {0}))", idx) + : formatv("processValue(llvmOperands[{0}])", idx); + bs << result; } else if (isResultName(op, name)) { if (op.getNumResults() != 1) PrintFatalError(record.getLoc(), "expected op to have one result");