diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1146,6 +1146,7 @@
   // all contraction operations. Also applies folding and DCE.
   {
     OwningRewritePatternList patterns;
+    populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
     populateVectorSlicesLoweringPatterns(patterns, &getContext());
     populateVectorContractLoweringPatterns(patterns, &getContext());
     applyPatternsAndFoldGreedily(getOperation(), patterns);
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1236,6 +1236,62 @@
   }
 };
 
+/// Progressive lowering of CreateMaskOp.
+/// One:
+///   %x = vector.create_mask %a, ... : vector<dx...>
+/// is replaced by:
+///   %l = vector.create_mask ... : vector<...>  ; one lower rank
+///   %0 = cmpi "slt", %ci, %a       |
+///   %1 = select %0, %l, %zeroes    |
+///   %r = vector.insert %1, %pr [i] | d-times
+///   %x = ....
+/// For a 1-D mask there is no lower-rank mask and no select: the result of
+/// each comparison is inserted directly.
+class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
+public:
+  using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto dstType = op.getResult().getType().cast<VectorType>();
+    auto eltType = dstType.getElementType();
+    int64_t rank = dstType.getRank();
+    // Bound of the leading dimension.
+    Value idx = op.getOperand(0);
+
+    // For rank > 1, build the mask over the trailing dimensions once, plus an
+    // all-zero vector of the same type, so each slice can select between them.
+    Value trueVal;
+    Value falseVal;
+    if (rank > 1) {
+      VectorType lowType =
+          VectorType::get(dstType.getShape().drop_front(), eltType);
+      trueVal = rewriter.create<vector::CreateMaskOp>(
+          loc, lowType, op.getOperands().drop_front());
+      falseVal = rewriter.create<ConstantOp>(loc, lowType,
+                                             rewriter.getZeroAttr(lowType));
+    }
+
+    // Start from an all-zero mask and insert one leading-dimension slice per
+    // iteration.
+    Value result = rewriter.create<ConstantOp>(loc, dstType,
+                                               rewriter.getZeroAttr(dstType));
+    for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; d++) {
+      // Slice d is enabled iff d < idx (signed comparison).
+      Value bnd = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(d));
+      Value val = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, bnd, idx);
+      if (rank > 1)
+        val = rewriter.create<SelectOp>(loc, val, trueVal, falseVal);
+      auto pos = rewriter.getI64ArrayAttr(d);
+      result =
+          rewriter.create<vector::InsertOp>(loc, dstType, val, result, pos);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 /// Progressive lowering of ContractionOp.
 /// One:
 ///   %x = vector.contract with at least one free/batch dimension
@@ -1659,6 +1715,6 @@
   patterns.insert(context);
+                  ConstantMaskOpLowering, CreateMaskOpLowering>(context);
   patterns.insert(parameters, context);
 }
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -605,3 +605,49 @@
   %v = vector.constant_mask [1, 1, 3] : vector<2x3x4xi1>
   return %v: vector<2x3x4xi1>
 }
+
+// CHECK-LABEL: func @genbool_var_1d
+// CHECK-SAME: %[[A:.*0]]: index
+// CHECK-DAG: %[[VF:.*]] = constant dense<false> : vector<3xi1>
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = constant 2 : index
+// CHECK: %[[T0:.*]] = cmpi "slt", %[[C0]], %[[A]] : index
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[VF]] [0] : i1 into vector<3xi1>
+// CHECK: %[[T2:.*]] = cmpi "slt", %[[C1]], %[[A]] : index
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : i1 into vector<3xi1>
+// CHECK: %[[T4:.*]] = cmpi "slt", %[[C2]], %[[A]] : index
+// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : i1 into vector<3xi1>
+// CHECK: return %[[T5]] : vector<3xi1>
+
+func @genbool_var_1d(%arg0: index) -> vector<3xi1> {
+  %0 = vector.create_mask %arg0 : vector<3xi1>
+  return %0 : vector<3xi1>
+}
+
+// CHECK-LABEL: func @genbool_var_2d
+// CHECK-SAME: %[[A:.*0]]: index
+// CHECK-SAME: %[[B:.*1]]: index
+// CHECK-DAG: %[[Z1:.*]] = constant dense<false> : vector<3xi1>
+// CHECK-DAG: %[[Z2:.*]] = constant dense<false> : vector<2x3xi1>
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = constant 2 : index
+// CHECK: %[[T0:.*]] = cmpi "slt", %[[C0]], %[[B]] : index
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[Z1]] [0] : i1 into vector<3xi1>
+// CHECK: %[[T2:.*]] = cmpi "slt", %[[C1]], %[[B]] : index
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : i1 into vector<3xi1>
+// CHECK: %[[T4:.*]] = cmpi "slt", %[[C2]], %[[B]] : index
+// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : i1 into vector<3xi1>
+// CHECK: %[[T6:.*]] = cmpi "slt", %[[C0]], %[[A]] : index
+// CHECK: %[[T7:.*]] = select %[[T6]], %[[T5]], %[[Z1]] : vector<3xi1>
+// CHECK: %[[T8:.*]] = vector.insert %[[T7]], %[[Z2]] [0] : vector<3xi1> into vector<2x3xi1>
+// CHECK: %[[T9:.*]] = cmpi "slt", %[[C1]], %[[A]] : index
+// CHECK: %[[T10:.*]] = select %[[T9]], %[[T5]], %[[Z1]] : vector<3xi1>
+// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T8]] [1] : vector<3xi1> into vector<2x3xi1>
+// CHECK: return %[[T11]] : vector<2x3xi1>
+
+func @genbool_var_2d(%arg0: index, %arg1: index) -> vector<2x3xi1> {
+  %0 = vector.create_mask %arg0, %arg1 : vector<2x3xi1>
+  return %0 : vector<2x3xi1>
+}
diff --git a/mlir/test/Target/vector-to-llvm-ir.mlir b/mlir/test/Target/vector-to-llvm-ir.mlir
--- a/mlir/test/Target/vector-to-llvm-ir.mlir
+++ b/mlir/test/Target/vector-to-llvm-ir.mlir
@@ -21,3 +21,11 @@
 // CHECK-LABEL: @genbool_3d()
 // CHECK-NEXT: ret [2 x [3 x <4 x i1>]] {{\[+}}3 x <4 x i1>] [<4 x i1> , <4 x i1> zeroinitializer, <4 x i1> zeroinitializer], [3 x <4 x i1>] zeroinitializer]
 // note: awkward syntax to match [[
+
+func @genbool_1d_var_but_constant() -> vector<8xi1> {
+  %i = constant 0 : index
+  %v = vector.create_mask %i : vector<8xi1>
+  return %v : vector<8xi1>
+}
+// CHECK-LABEL: @genbool_1d_var_but_constant()
+// CHECK-NEXT: ret <8 x i1> zeroinitializer