diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -1150,6 +1150,102 @@ let hasFolder = 1; } +def Vector_MaskedLoadOp : + Vector_Op<"maskedload">, + Arguments<(ins AnyMemRef:$base, + VectorOfRankAndType<[1], [I1]>:$mask, + VectorOfRank<[1]>:$pass_thru)>, + Results<(outs VectorOfRank<[1]>:$result)> { + + let summary = "loads elements from memory into a vector as defined by a mask vector"; + + let description = [{ + The masked load reads elements from memory into a 1-D vector as defined + by a base and a 1-D mask vector. When the mask is set, the element is read + from memory. Otherwise, the corresponding element is taken from a 1-D + pass-through vector. Informally the semantics are: + ``` + result[0] := mask[0] ? MEM[base+0] : pass_thru[0] + result[1] := mask[1] ? MEM[base+1] : pass_thru[1] + etc. + ``` + The masked load can be used directly where applicable, or can be used + during progressive lowering to bring other memory operations closer to + hardware ISA support for a masked load. The semantics of the operation + closely correspond to those of the `llvm.masked.load` + [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-load-intrinsics). 
+ + Example: + + ```mlir + %0 = vector.maskedload %base, %mask, %pass_thru + : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> + ``` + }]; + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return base().getType().cast(); + } + VectorType getMaskVectorType() { + return mask().getType().cast(); + } + VectorType getPassThruVectorType() { + return pass_thru().getType().cast(); + } + VectorType getResultVectorType() { + return result().getType().cast(); + } + }]; + let assemblyFormat = "$base `,` $mask `,` $pass_thru attr-dict `:` " + "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)"; +} + +def Vector_MaskedStoreOp : + Vector_Op<"maskedstore">, + Arguments<(ins AnyMemRef:$base, + VectorOfRankAndType<[1], [I1]>:$mask, + VectorOfRank<[1]>:$value)> { + + let summary = "stores elements from a vector into memory as defined by a mask vector"; + + let description = [{ + The masked store operation writes elements from a 1-D vector into memory + as defined by a base and a 1-D mask vector. When the mask is set, the + corresponding element from the vector is written to memory. Otherwise, + no action is taken for the element. Informally the semantics are: + ``` + if (mask[0]) MEM[base+0] = value[0] + if (mask[1]) MEM[base+1] = value[1] + etc. + ``` + The masked store can be used directly where applicable, or can be used + during progressive lowering to bring other memory operations closer to + hardware ISA support for a masked store. The semantics of the operation + closely correspond to those of the `llvm.masked.store` + [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-store-intrinsics). 
+ + Example: + + ```mlir + vector.maskedstore %base, %mask, %value + : memref, vector<8xi1>, vector<8xf32> + ``` + }]; + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return base().getType().cast(); + } + VectorType getMaskVectorType() { + return mask().getType().cast(); + } + VectorType getValueVectorType() { + return value().getType().cast(); + } + }]; + let assemblyFormat = "$base `,` $mask `,` $value attr-dict `:` " + "type($mask) `,` type($value) `into` type($base)"; +} + def Vector_GatherOp : Vector_Op<"gather">, Arguments<(ins AnyMemRef:$base, diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-maskedload.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-maskedload.mlir new file mode 100644 --- /dev/null +++ b/mlir/integration_test/Dialect/Vector/CPU/test-maskedload.mlir @@ -0,0 +1,66 @@ +// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +func @maskedload16(%base: memref, %mask: vector<16xi1>, + %pass_thru: vector<16xf32>) -> vector<16xf32> { + %ld = vector.maskedload %base, %mask, %pass_thru + : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> + return %ld : vector<16xf32> +} + +func @entry() { + // Set up memory. + %c0 = constant 0: index + %c1 = constant 1: index + %c16 = constant 16: index + %A = alloc(%c16) : memref + scf.for %i = %c0 to %c16 step %c1 { + %i32 = index_cast %i : index to i32 + %fi = sitofp %i32 : i32 to f32 + store %fi, %A[%i] : memref + } + + // Set up pass thru vector. + %u = constant -7.0: f32 + %pass = vector.broadcast %u : f32 to vector<16xf32> + + // Set up masks. 
+ %f = constant 0: i1 + %t = constant 1: i1 + %none = vector.constant_mask [0] : vector<16xi1> + %all = vector.constant_mask [16] : vector<16xi1> + %some = vector.constant_mask [8] : vector<16xi1> // lanes 0-7 set + %0 = vector.insert %f, %some[0] : i1 into vector<16xi1> + %1 = vector.insert %t, %0[13] : i1 into vector<16xi1> + %2 = vector.insert %t, %1[14] : i1 into vector<16xi1> + %other = vector.insert %t, %2[14] : i1 into vector<16xi1> // NOTE(review): lane 14 is already set in %2, so this insert is a no-op; the expected %l4 output below assumes only lanes 1-7, 13 and 14 are set - confirm whether a different lane was intended. + + // + // Masked load tests. + // + + %l1 = call @maskedload16(%A, %none, %pass) + : (memref, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>) + vector.print %l1 : vector<16xf32> + // CHECK: ( -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7 ) + + %l2 = call @maskedload16(%A, %all, %pass) + : (memref, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>) + vector.print %l2 : vector<16xf32> + // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 ) + + %l3 = call @maskedload16(%A, %some, %pass) + : (memref, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>) + vector.print %l3 : vector<16xf32> + // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, -7, -7, -7, -7, -7, -7, -7, -7 ) + + %l4 = call @maskedload16(%A, %other, %pass) + : (memref, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>) + vector.print %l4 : vector<16xf32> + // CHECK: ( -7, 1, 2, 3, 4, 5, 6, 7, -7, -7, -7, -7, -7, 13, 14, -7 ) + + return +} + diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-maskedstore.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-maskedstore.mlir new file mode 100644 --- /dev/null +++ b/mlir/integration_test/Dialect/Vector/CPU/test-maskedstore.mlir @@ -0,0 +1,89 @@ +// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: 
-shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +func @maskedstore16(%base: memref, + %mask: vector<16xi1>, %value: vector<16xf32>) { + vector.maskedstore %base, 
+ %mask, %value + : vector<16xi1>, vector<16xf32> into memref + return +} + +func @printmem16(%A: memref) { // Copies the 16 elements of %A into a vector and prints it. + %c0 = constant 0: index + %c1 = constant 1: index + %c16 = constant 16: index + %z = constant 0.0: f32 + %m = vector.broadcast %z : f32 to vector<16xf32> + %mem = scf.for %i = %c0 to %c16 step %c1 + iter_args(%m_iter = %m) -> (vector<16xf32>) { + %c = load %A[%i] : memref + %i32 = index_cast %i : index to i32 + %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<16xf32> + scf.yield %m_new : vector<16xf32> + } + vector.print %mem : vector<16xf32> + return +} + +func @entry() { + // Set up memory. + %f0 = constant 0.0: f32 + %c0 = constant 0: index + %c1 = constant 1: index + %c16 = constant 16: index + %A = alloc(%c16) : memref + scf.for %i = %c0 to %c16 step %c1 { + store %f0, %A[%i] : memref + } + + // Set up value vector: lane i holds float(i). + %v = vector.broadcast %f0 : f32 to vector<16xf32> + %val = scf.for %i = %c0 to %c16 step %c1 + iter_args(%v_iter = %v) -> (vector<16xf32>) { + %i32 = index_cast %i : index to i32 + %fi = sitofp %i32 : i32 to f32 + %v_new = vector.insertelement %fi, %v_iter[%i32 : i32] : vector<16xf32> + scf.yield %v_new : vector<16xf32> + } + + // Set up masks. + %t = constant 1: i1 + %none = vector.constant_mask [0] : vector<16xi1> + %some = vector.constant_mask [8] : vector<16xi1> // lanes 0-7 set + %more = vector.insert %t, %some[13] : i1 into vector<16xi1> // lanes 0-7 and 13 set + %all = vector.constant_mask [16] : vector<16xi1> + + // + // Masked store tests. 
+ // + + vector.print %val : vector<16xf32> + // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 ) + + call @printmem16(%A) : (memref) -> () + // CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + + call @maskedstore16(%A, %none, %val) + : (memref, vector<16xi1>, vector<16xf32>) -> () + call @printmem16(%A) : (memref) -> () + // CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + + call @maskedstore16(%A, %some, %val) + : (memref, vector<16xi1>, vector<16xf32>) -> () + call @printmem16(%A) : (memref) -> () + // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0 ) + + call @maskedstore16(%A, %more, %val) + : (memref, vector<16xi1>, vector<16xf32>) -> () + call @printmem16(%A) : (memref) -> () + // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0, 0, 13, 0, 0 ) + + call @maskedstore16(%A, %all, %val) + : (memref, vector<16xi1>, vector<16xf32>) -> () + call @printmem16(%A) : (memref) -> () + // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 ) + + return +} diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -162,6 +162,19 @@ return success(); } +// Helper that returns a bit-casted pointer given a memref base: the base pointer obtained via getBase() is bitcast to a pointer to `type` and returned through `ptr`; returns failure if getBase() fails. +LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, Location loc, + Value memref, MemRefType memRefType, Type type, + Value &ptr) { + Value base; + if (failed(getBase(rewriter, loc, memref, memRefType, base))) + return failure(); + auto pType = type.template cast().getPointerTo(); // `type` is expected to be an already-converted LLVM type (callers pass the result of typeConverter.convertType). 
+ base = rewriter.create(loc, pType, base); + ptr = rewriter.create(loc, pType, base); // NOTE(review): presumably an index-free GEP (the lowering tests expect `llvm.getelementptr %{{.*}}[]`), so the address is unchanged. + return success(); +} + // Helper that returns vector of pointers given a memref base and an index // vector. 
LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, @@ -297,6 +310,72 @@ } }; +/// Conversion pattern for a vector.maskedload: lowers it to an LLVM masked load that reads through the bit-casted memref base pointer and carries the memref alignment as an i32 attribute. +class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern { +public: + explicit VectorMaskedLoadOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(vector::MaskedLoadOp::getOperationName(), context, + typeConverter) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto load = cast(op); + auto adaptor = vector::MaskedLoadOpAdaptor(operands); + + // Resolve alignment. + unsigned align; + if (failed(getMemRefAlignment(typeConverter, load, align))) + return failure(); + + auto vtype = typeConverter.convertType(load.getResultVectorType()); // NOTE(review): result is not checked for a failed (null) conversion before use - confirm this cannot happen for 1-D vectors. + Value ptr; + if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(), + vtype, ptr))) + return failure(); + + rewriter.replaceOpWithNewOp( + load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(), + rewriter.getI32IntegerAttr(align)); + return success(); + } +}; + +/// Conversion pattern for a vector.maskedstore: lowers it to an LLVM masked store that writes through the bit-casted memref base pointer and carries the memref alignment as an i32 attribute. 
+class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern { +public: + explicit VectorMaskedStoreOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(vector::MaskedStoreOp::getOperationName(), context, + typeConverter) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto store = cast(op); + auto adaptor = vector::MaskedStoreOpAdaptor(operands); + + // Resolve alignment. 
+ unsigned align; + if (failed(getMemRefAlignment(typeConverter, store, align))) + return failure(); + + auto vtype = typeConverter.convertType(store.getValueVectorType()); // NOTE(review): result is not checked for a failed (null) conversion before use - confirm this cannot happen for 1-D vectors. + Value ptr; + if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(), + vtype, ptr))) + return failure(); + + rewriter.replaceOpWithNewOp( + store, adaptor.value(), ptr, adaptor.mask(), + rewriter.getI32IntegerAttr(align)); + return success(); + } +}; + /// Conversion pattern for a vector.gather. class VectorGatherOpConversion : public ConvertToLLVMPattern { public: @@ -1341,6 +1420,8 @@ VectorTransferConversion, VectorTransferConversion, VectorTypeCastOpConversion, + VectorMaskedLoadOpConversion, + VectorMaskedStoreOpConversion, VectorGatherOpConversion, VectorScatterOpConversion, VectorExpandLoadOpConversion, diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1855,6 +1855,41 @@ return llvm::to_vector<4>(getVectorType().getShape()); } +//===----------------------------------------------------------------------===// +// MaskedLoadOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(MaskedLoadOp op) { // Checks result element type vs. memref, result length vs. mask, and pass_thru type vs. result. 
+ VectorType maskVType = op.getMaskVectorType(); + VectorType passVType = op.getPassThruVectorType(); + VectorType resVType = op.getResultVectorType(); + + if (resVType.getElementType() != op.getMemRefType().getElementType()) + return op.emitOpError("base and result element type should match"); + + if (resVType.getDimSize(0) != maskVType.getDimSize(0)) + return op.emitOpError("expected result dim to match mask dim"); + if (resVType != passVType) + return op.emitOpError("expected pass_thru of same type as result type"); + return success(); +} + +//===----------------------------------------------------------------------===// +// MaskedStoreOp 
+//===----------------------------------------------------------------------===// + +static LogicalResult verify(MaskedStoreOp op) { // Checks value element type vs. memref and value length vs. mask. + VectorType maskVType = op.getMaskVectorType(); + VectorType valueVType = op.getValueVectorType(); + + if (valueVType.getElementType() != op.getMemRefType().getElementType()) + return op.emitOpError("base and value element type should match"); + + if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) + return op.emitOpError("expected value dim to match mask dim"); + return success(); +} + //===----------------------------------------------------------------------===// // GatherOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -970,6 +970,26 @@ // CHECK-SAME: !llvm.vec<16 x float> into !llvm.vec<16 x float> // CHECK: llvm.return %[[T]] : !llvm.vec<16 x float> +func @masked_load_op(%arg0: memref, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> { + %0 = vector.maskedload %arg0, %arg1, %arg2 : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> + return %0 : vector<16xf32> +} + +// CHECK-LABEL: func @masked_load_op +// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr>) -> !llvm.ptr> +// CHECK: %[[L:.*]] = llvm.intr.masked.load %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.ptr>, !llvm.vec<16 x i1>, !llvm.vec<16 x float>) -> !llvm.vec<16 x float> +// CHECK: llvm.return %[[L]] : !llvm.vec<16 x float> + +func @masked_store_op(%arg0: memref, %arg1: vector<16xi1>, %arg2: vector<16xf32>) { + vector.maskedstore %arg0, %arg1, %arg2 : vector<16xi1>, vector<16xf32> into memref + return +} + +// CHECK-LABEL: func @masked_store_op +// CHECK: %[[P:.*]] = 
llvm.getelementptr %{{.*}}[] : (!llvm.ptr>) -> !llvm.ptr> +// CHECK: 
llvm.intr.masked.store %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : !llvm.vec<16 x float>, !llvm.vec<16 x i1> into !llvm.ptr> +// CHECK: llvm.return + func @gather_op(%arg0: memref, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> { %0 = vector.gather %arg0, %arg1, %arg2, %arg3 : (memref, vector<3xi32>, vector<3xi1>, vector<3xf32>) -> vector<3xf32> return %0 : vector<3xf32> diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1180,6 +1180,41 @@ // ----- +func @maskedload_base_type_mismatch(%base: memref, %mask: vector<16xi1>, %pass: vector<16xf32>) { + // expected-error@+1 {{'vector.maskedload' op base and result element type should match}} + %0 = vector.maskedload %base, %mask, %pass : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> +} + +// ----- + +func @maskedload_dim_mask_mismatch(%base: memref, %mask: vector<15xi1>, %pass: vector<16xf32>) { + // expected-error@+1 {{'vector.maskedload' op expected result dim to match mask dim}} + %0 = vector.maskedload %base, %mask, %pass : memref, vector<15xi1>, vector<16xf32> into vector<16xf32> +} + +// ----- + +func @maskedload_pass_thru_type_mask_mismatch(%base: memref, %mask: vector<16xi1>, %pass: vector<16xi32>) { + // expected-error@+1 {{'vector.maskedload' op expected pass_thru of same type as result type}} + %0 = vector.maskedload %base, %mask, %pass : memref, vector<16xi1>, vector<16xi32> into vector<16xf32> +} + +// ----- + +func @maskedstore_base_type_mismatch(%base: memref, %mask: vector<16xi1>, %value: vector<16xf32>) { + // expected-error@+1 {{'vector.maskedstore' op base and value element type should match}} + vector.maskedstore %base, %mask, %value : vector<16xi1>, vector<16xf32> into memref +} + +// ----- + +func @maskedstore_dim_mask_mismatch(%base: memref, %mask: vector<15xi1>, %value: vector<16xf32>) { + // expected-error@+1 
{{'vector.maskedstore' op expected value dim to match mask dim}} + vector.maskedstore %base, %mask, %value : vector<15xi1>, vector<16xf32> into memref +} + +// ----- + func @gather_base_type_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<16xi1>) { // expected-error@+1 {{'vector.gather' op base and result element type should match}} %0 = vector.gather %base, %indices, %mask : (memref, vector<16xi32>, vector<16xi1>) -> vector<16xf32> diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -369,6 +369,15 @@ return %0 : vector<16xi32> } +// CHECK-LABEL: @masked_load_and_store +func @masked_load_and_store(%base: memref, %mask: vector<16xi1>, %passthru: vector<16xf32>) { + // CHECK: %[[X:.*]] = vector.maskedload %{{.*}}, %{{.*}}, %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> + %0 = vector.maskedload %base, %mask, %passthru : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> + // CHECK: vector.maskedstore %{{.*}}, %{{.*}}, %[[X]] : vector<16xi1>, vector<16xf32> into memref + vector.maskedstore %base, %mask, %0 : vector<16xi1>, vector<16xf32> into memref + return +} + // CHECK-LABEL: @gather_and_scatter func @gather_and_scatter(%base: memref, %indices: vector<16xi32>, %mask: vector<16xi1>) { // CHECK: %[[X:.*]] = vector.gather %{{.*}}, %{{.*}}, %{{.*}} : (memref, vector<16xi32>, vector<16xi1>) -> vector<16xf32>