diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -808,6 +808,44 @@
     "`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";
 }
 
+//
+// LLVM masked operations.
+//
+
+/// Create a call to Masked Load intrinsic.
+/// `data` is the pointer to load from, `mask` selects the active lanes and
+/// `alignment` is the byte alignment. `pass_thru` is optional: when empty the
+/// builder's default (an undef vector) fills the masked-off lanes; otherwise
+/// only its first operand is used and any further operands are ignored.
+def LLVM_MaskedLoadOp
+    : LLVM_OneResultOp<"intr.masked.load">,
+      Arguments<(ins LLVM_Type:$data, LLVM_Type:$mask,
+                 Variadic<LLVM_Type>:$pass_thru, I32Attr:$alignment)> {
+  string llvmBuilder = [{
+    $res = $pass_thru.empty() ? builder.CreateMaskedLoad(
+      $data, llvm::Align($alignment.getZExtValue()), $mask) :
+      builder.CreateMaskedLoad(
+        $data, llvm::Align($alignment.getZExtValue()), $mask, $pass_thru[0]);
+  }];
+  let assemblyFormat =
+    "operands attr-dict `:` functional-type(operands, results)";
+}
+
+/// Create a call to Masked Store intrinsic.
+/// Writes the lanes of `value` selected by `mask` through the pointer `data`;
+/// `alignment` is the byte alignment passed to the intrinsic.
+def LLVM_MaskedStoreOp
+    : LLVM_ZeroResultOp<"intr.masked.store">,
+      Arguments<(ins LLVM_Type:$value, LLVM_Type:$data, LLVM_Type:$mask,
+                 I32Attr:$alignment)> {
+  string llvmBuilder = [{
+    builder.CreateMaskedStore(
+      $value, $data, llvm::Align($alignment.getZExtValue()), $mask);
+  }];
+  let assemblyFormat = "$value `,` $data `,` $mask attr-dict `:` "
+    "type($value) `,` type($mask) `into` type($data)";
+}
+
 //
 // Atomic operations.
// diff --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir --- a/mlir/test/Target/llvmir-intrinsics.mlir +++ b/mlir/test/Target/llvmir-intrinsics.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s +// RUN: mlir-translate -mlir-to-llvmir %s +//| FileCheck %s // CHECK-LABEL: @intrinsics llvm.func @intrinsics(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm<"<8 x float>">, %arg3: !llvm<"i8*">) { @@ -143,6 +144,20 @@ llvm.return %C: !llvm<"<12 x float>"> } +// CHECK-LABEL: @masked_intrinsics +llvm.func @masked_intrinsics(%A: !llvm<"<7 x float>*">, %mask: !llvm<"<7 x i1>">) { + // CHECK: call <7 x float> @llvm.masked.load.v7f32.p0v7f32(<7 x float>* %{{.*}}, i32 1, <7 x i1> %{{.*}}, <7 x float> undef) + %a = llvm.intr.masked.load %A, %mask { alignment = 1: i32} : + (!llvm<"<7 x float>*">, !llvm<"<7 x i1>">) -> !llvm<"<7 x float>"> + // CHECK: call <7 x float> @llvm.masked.load.v7f32.p0v7f32(<7 x float>* %{{.*}}, i32 1, <7 x i1> %{{.*}}, <7 x float> %{{.*}}) + %b = llvm.intr.masked.load %A, %mask, %a { alignment = 1: i32} : + (!llvm<"<7 x float>*">, !llvm<"<7 x i1>">, !llvm<"<7 x float>">) -> !llvm<"<7 x float>"> + // CHECK: call void @llvm.masked.store.v7f32.p0v7f32(<7 x float> %{{.*}}, <7 x float>* %0, i32 {{.*}}, <7 x i1> %{{.*}}) + llvm.intr.masked.store %b, %A, %mask { alignment = 1: i32} : + !llvm<"<7 x float>">, !llvm<"<7 x i1>"> into !llvm<"<7 x float>*"> + llvm.return +} + // Check that intrinsics are declared with appropriate types. 
// CHECK-DAG: declare float @llvm.fma.f32(float, float, float) // CHECK-DAG: declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>) #0 @@ -167,3 +182,5 @@ // CHECK-DAG: declare <8 x float> @llvm.cos.v8f32(<8 x float>) #0 // CHECK-DAG: declare float @llvm.copysign.f32(float, float) // CHECK-DAG: declare <12 x float> @llvm.matrix.multiply.v12f32.v64f32.v48f32(<64 x float>, <48 x float>, i32 immarg, i32 immarg, i32 immarg) +// CHECK-DAG: declare <7 x float> @llvm.masked.load.v7f32.p0v7f32(<7 x float>*, i32 immarg, <7 x i1>, <7 x float>) +// CHECK-DAG: declare void @llvm.masked.store.v7f32.p0v7f32(<7 x float>, <7 x float>*, i32 immarg, <7 x i1>)