diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1072,6 +1072,79 @@
   const bool force32BitVectorIndices;
 };
 
+/// Returns true if all the `i1` elements of `constantOp` are set to `value`.
+/// Returns false for constants that are not backed by a dense elements
+/// attribute.
+static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
+  auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValue());
+  // TODO: Support non-dense constant.
+  if (!denseAttr)
+    return false;
+
+  assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
+
+  for (bool i1Value : denseAttr.getValues<bool>()) {
+    if (i1Value != value)
+      return false;
+  }
+
+  return true;
+}
+
+/// Folds a select operation between an all-true and all-false vector. For now,
+/// only single element vectors (i.e., vector<1xi1>) are supported. That is:
+///
+///   %true = arith.constant dense<true> : vector<1xi1>
+///   %false = arith.constant dense<false> : vector<1xi1>
+///   %result = arith.select %cond, %true, %false : i1, vector<1xi1>
+///   =>
+///   %result = vector.insert %cond, %zero [0] : i1 into vector<1xi1>
+///
+/// InstCombine seems to handle vectors with multiple elements but not the
+/// single element ones.
+struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::SelectOp selectOp,
+                                PatternRewriter &rewriter) const override {
+    auto vecType = dyn_cast<VectorType>(selectOp.getType());
+    if (!vecType || !vecType.getElementType().isInteger(1))
+      return failure();
+
+    // Only a scalar condition can be inserted as a single vector element.
+    Value cond = selectOp.getCondition();
+    if (isa<VectorType>(cond.getType()))
+      return failure();
+
+    // TODO: Support n-D and scalable vectors.
+    if (vecType.getRank() != 1 || vecType.isScalable())
+      return failure();
+
+    // TODO: Support vectors with multiple elements.
+    if (vecType.getShape()[0] != 1)
+      return failure();
+
+    auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
+    if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
+      return failure();
+
+    auto falseConst =
+        selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
+    if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
+      return failure();
+
+    // Replace the select with its condition inserted into a zero-initialized
+    // vector of the result type (the checks above pin this to vector<1xi1>),
+    // rather than deriving an element bit-width from the element count.
+    Location loc = selectOp.getLoc();
+    Value zeroVec = rewriter.createOrFold<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(vecType));
+    rewriter.replaceOpWithNewOp<vector::InsertOp>(selectOp, cond, zeroVec,
+                                                  /*pos=*/0);
+    return success();
+  }
+};
+
 // Drop inner most contiguous unit dimensions from transfer_read operand.
 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -1327,6 +1400,7 @@
                MaterializeTransferMask<vector::TransferReadOp>,
                MaterializeTransferMask<vector::TransferWriteOp>>(
       patterns.getContext(), force32BitVectorIndices, benefit);
+  patterns.add<FoldI1Select>(patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
diff --git a/mlir/test/Dialect/Vector/vector-materialize-mask.mlir b/mlir/test/Dialect/Vector/vector-materialize-mask.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-materialize-mask.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
+
+func.func @select_single_i1_vector(%cond : i1) -> vector<1xi1> {
+  %true = arith.constant dense<true> : vector<1xi1>
+  %false = arith.constant dense<false> : vector<1xi1>
+  %select = arith.select %cond, %true, %false : i1, vector<1xi1>
+  return %select : vector<1xi1>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%func_op: !transform.op<"func.func">):
+  transform.apply_patterns to %func_op {
+    transform.apply_patterns.vector.materialize_masks
+  } : !transform.op<"func.func">
+}
+
+// CHECK-LABEL: func @select_single_i1_vector
+// CHECK-SAME: %[[COND:.*]]: i1
+// CHECK: %[[INSERT:.*]] = vector.insert %[[COND]], %{{.*}} [0] : i1 into vector<1xi1>
+// CHECK: return %[[INSERT]] : vector<1xi1>