diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1187,6 +1187,55 @@
   }
 };
 
+/// Progressive lowering of ConstantMaskOp.
+/// One:
+///   %x = vector.constant_mask [a,b]
+/// is replaced by:
+///   %z = zero-result
+///   %l = vector.constant_mask [b]
+///   %r = vector.insert %l, %z[0]
+///   ..
+///   %x = vector.insert %l, %..[a-1]
+/// until a one-dimensional mask is reached, where %l is the scalar constant 1.
+/// All inserts will be folded at LLVM IR level.
+class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
+public:
+  using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto dstType = op.getResult().getType().cast<VectorType>();
+    auto eltType = dstType.getElementType();
+    auto dimSizes = op.mask_dim_sizes();
+    int64_t rank = dimSizes.size();
+    int64_t trueDim = dimSizes[0].cast<IntegerAttr>().getInt();
+
+    Value trueVal; // Stored at each true index along the leading dimension.
+    if (rank == 1) {
+      trueVal = rewriter.create<ConstantOp>(
+          loc, eltType, rewriter.getIntegerAttr(eltType, 1));
+    } else {
+      auto lowType = VectorType::get(dstType.getShape().drop_front(), eltType);
+      SmallVector<int64_t, 4> newDimSizes;
+      for (int64_t r = 1; r < rank; r++)
+        newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
+      trueVal = rewriter.create<vector::ConstantMaskOp>(
+          loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
+    }
+
+    // Start from the all-false mask and set the first trueDim entries of dim 0.
+    Value result = rewriter.create<ConstantOp>(loc, dstType,
+                                               rewriter.getZeroAttr(dstType));
+    for (int64_t d = 0; d < trueDim; d++) {
+      result = rewriter.create<vector::InsertOp>(
+          loc, dstType, trueVal, result, rewriter.getI64ArrayAttr(d));
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 /// Progressive lowering of ContractionOp.
/// One: /// %x = vector.contract with at least one free/batch dimension @@ -1609,6 +1658,7 @@ VectorTransformsOptions parameters) { patterns.insert(context); + TransposeOpLowering, OuterProductOpLowering, + ConstantMaskOpLowering>(context); patterns.insert(parameters, context); } diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -917,3 +917,20 @@ // CHECK-SAME: (!llvm<"float addrspace(3)*">, !llvm.i64) -> !llvm<"float addrspace(3)*"> // CHECK: %[[vecPtr_b:.*]] = llvm.addrspacecast %[[gep_b]] : // CHECK-SAME: !llvm<"float addrspace(3)*"> to !llvm<"<17 x float>*"> + +func @genbool_1d() -> vector<8xi1> { + %0 = vector.constant_mask [4] : vector<8xi1> + return %0 : vector<8xi1> +} +// CHECK-LABEL: func @genbool_1d +// CHECK: %[[T0:.*]] = llvm.mlir.constant(1 : i1) : !llvm.i1 +// CHECK: %[[T1:.*]] = llvm.mlir.constant(dense : vector<8xi1>) : !llvm<"<8 x i1>"> +// CHECK: %[[T2:.*]] = llvm.mlir.constant(0 : i64) : !llvm.i64 +// CHECK: %[[T3:.*]] = llvm.insertelement %[[T0]], %[[T1]][%[[T2]] : !llvm.i64] : !llvm<"<8 x i1>"> +// CHECK: %[[T4:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64 +// CHECK: %[[T5:.*]] = llvm.insertelement %[[T0]], %[[T3]][%[[T4]] : !llvm.i64] : !llvm<"<8 x i1>"> +// CHECK: %[[T6:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 +// CHECK: %[[T7:.*]] = llvm.insertelement %[[T0]], %[[T5]][%[[T6]] : !llvm.i64] : !llvm<"<8 x i1>"> +// CHECK: %[[T8:.*]] = llvm.mlir.constant(3 : i64) : !llvm.i64 +// CHECK: %[[T9:.*]] = llvm.insertelement %[[T0]], %[[T7]][%[[T8]] : !llvm.i64] : !llvm<"<8 x i1>"> +// CHECK: llvm.return %9 : !llvm<"<8 x i1>"> diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ 
b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -559,3 +559,49 @@ %0 = vector.broadcast %arg0 : vector<4x1x2xf32> to vector<4x3x2xf32> return %0 : vector<4x3x2xf32> } + +// CHECK-LABEL: func @genbool_1d +// CHECK: %[[TT:.*]] = constant 1 : i1 +// CHECK: %[[C1:.*]] = constant dense : vector<8xi1> +// CHECK: %[[T0.*]] = vector.insert %[[TT]], %[[C1]] [0] : i1 into vector<8xi1> +// CHECK: %[[T1.*]] = vector.insert %[[TT]], %[[T0]] [1] : i1 into vector<8xi1> +// CHECK: %[[T2.*]] = vector.insert %[[TT]], %[[T1]] [2] : i1 into vector<8xi1> +// CHECK: %[[T3.*]] = vector.insert %[[TT]], %[[T2]] [3] : i1 into vector<8xi1> +// CHECK: return %[[T3]] : vector<8xi1> + +func @genbool_1d() -> vector<8xi1> { + %0 = vector.constant_mask [4] : vector<8xi1> + return %0 : vector<8xi1> +} + +// CHECK-LABEL: func @genbool_2d +// CHECK: %[[TT:.*]] = constant 1 : i1 +// CHECK: %[[C1:.*]] = constant dense : vector<4xi1> +// CHECK: %[[C2:.*]] = constant dense : vector<4x4xi1> +// CHECK: %[[T0:.*]] = vector.insert %[[TT]], %[[C1]] [0] : i1 into vector<4xi1> +// CHECK: %[[T1:.*]] = vector.insert %[[TT]], %[[T0]] [1] : i1 into vector<4xi1> +// CHECK: %[[T2:.*]] = vector.insert %[[T1]], %[[C2]] [0] : vector<4xi1> into vector<4x4xi1> +// CHECK: %[[T3:.*]] = vector.insert %[[T1]], %[[T2]] [1] : vector<4xi1> into vector<4x4xi1> +// CHECK: return %[[T3]] : vector<4x4xi1> + +func @genbool_2d() -> vector<4x4xi1> { + %v = vector.constant_mask [2, 2] : vector<4x4xi1> + return %v: vector<4x4xi1> +} + +// CHECK-LABEL: func @genbool_3d +// CHECK: %[[Tt:.*]] = constant 1 : i1 +// CHECK: %[[C1:.*]] = constant dense : vector<4xi1> +// CHECK: %[[C2:.*]] = constant dense : vector<3x4xi1> +// CHECK: %[[C3:.*]] = constant dense : vector<2x3x4xi1> +// CHECK: %[[T0:.*]] = vector.insert %[[TT]], %[[C1]] [0] : i1 into vector<4xi1> +// CHECK: %[[T1:.*]] = vector.insert %[[TT]], %[[T0]] [1] : i1 into vector<4xi1> +// CHECK: %[[T2:.*]] = vector.insert %[[TT]], %[[T1]] [2] : i1 into vector<4xi1> +// CHECK: 
%[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<4xi1> into vector<3x4xi1> +// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C3]] [0] : vector<3x4xi1> into vector<2x3x4xi1> +// CHECK: return %[[T4]] : vector<2x3x4xi1> + +func @genbool_3d() -> vector<2x3x4xi1> { + %v = vector.constant_mask [1, 1, 3] : vector<2x3x4xi1> + return %v: vector<2x3x4xi1> +} diff --git a/mlir/test/Target/vector-to-llvm-ir.mlir b/mlir/test/Target/vector-to-llvm-ir.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Target/vector-to-llvm-ir.mlir @@ -0,0 +1,23 @@ +// RUN: mlir-opt %s -convert-vector-to-llvm | mlir-translate -mlir-to-llvmir | FileCheck %s + +func @genbool_1d() -> vector<8xi1> { + %0 = vector.constant_mask [4] : vector<8xi1> + return %0 : vector<8xi1> +} +// CHECK-LABEL: @genbool_1d() +// CHECK-NEXT: ret <8 x i1> + +func @genbool_2d() -> vector<4x4xi1> { + %v = vector.constant_mask [2, 2] : vector<4x4xi1> + return %v: vector<4x4xi1> +} +// CHECK-LABEL: @genbool_2d() +// CHECK-NEXT: ret [4 x <4 x i1>] [<4 x i1> , <4 x i1> , <4 x i1> zeroinitializer, <4 x i1> zeroinitializer] + +func @genbool_3d() -> vector<2x3x4xi1> { + %v = vector.constant_mask [1, 1, 3] : vector<2x3x4xi1> + return %v: vector<2x3x4xi1> +} +// CHECK-LABEL: @genbool_3d() +// CHECK-NEXT: ret [2 x [3 x <4 x i1>]] {{\[+}}3 x <4 x i1>] [<4 x i1> , <4 x i1> zeroinitializer, <4 x i1> zeroinitializer], [3 x <4 x i1>] zeroinitializer] +// note: awkward syntax to match [[