diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1328,6 +1328,8 @@ int64_t trueDim = dimSizes[0].cast().getInt(); if (rank == 1) { + // Express constant 1-D case in explicit vector form: + // [T,..,T,F,..,F]. SmallVector values(dstType.getDimSize(0)); for (int64_t d = 0; d < trueDim; d++) values[d] = true; @@ -1364,8 +1366,7 @@ /// %1 = select %0, %l, %zeroes | /// %r = vector.insert %1, %pr [i] | d-times /// %x = .... -/// When rank == 1, the selection operator is not needed, -/// and we can assign the true/false value right away. +/// until a one-dimensional vector is reached. class CreateMaskOpLowering : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -1375,30 +1376,41 @@ auto loc = op.getLoc(); auto dstType = op.getResult().getType().cast(); auto eltType = dstType.getElementType(); + int64_t dim = dstType.getDimSize(0); int64_t rank = dstType.getRank(); Value idx = op.getOperand(0); - Value trueVal; - Value falseVal; - if (rank > 1) { - VectorType lowType = - VectorType::get(dstType.getShape().drop_front(), eltType); - trueVal = rewriter.create( - loc, lowType, op.getOperands().drop_front()); - falseVal = rewriter.create(loc, lowType, - rewriter.getZeroAttr(lowType)); + if (rank == 1) { + // Express dynamic 1-D case in explicit vector form: + // mask = [0,1,..,n-1] < [a,a,..,a] + SmallVector values(dim); + for (int64_t d = 0; d < dim; d++) + values[d] = d; + Value indices = + rewriter.create(loc, rewriter.getI64VectorAttr(values)); + Value bound = + rewriter.create(loc, rewriter.getI64Type(), idx); + Value bounds = rewriter.create(loc, indices.getType(), bound); + rewriter.replaceOpWithNewOp(op, CmpIPredicate::slt, indices, + bounds); + return success(); } + VectorType lowType = + VectorType::get(dstType.getShape().drop_front(), eltType); + Value trueVal = 
rewriter.create( + loc, lowType, op.getOperands().drop_front()); + Value falseVal = rewriter.create(loc, lowType, + rewriter.getZeroAttr(lowType)); Value result = rewriter.create(loc, dstType, rewriter.getZeroAttr(dstType)); - for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; d++) { + for (int64_t d = 0; d < dim; d++) { Value bnd = rewriter.create(loc, rewriter.getIndexAttr(d)); Value val = rewriter.create(loc, CmpIPredicate::slt, bnd, idx); - if (rank > 1) - val = rewriter.create(loc, val, trueVal, falseVal); + Value sel = rewriter.create(loc, val, trueVal, falseVal); auto pos = rewriter.getI64ArrayAttr(d); result = - rewriter.create(loc, dstType, val, result, pos); + rewriter.create(loc, dstType, sel, result, pos); } rewriter.replaceOp(op, result); return success(); diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -710,18 +710,12 @@ } // CHECK-LABEL: func @genbool_var_1d -// CHECK-SAME: %[[A:.*0]]: index -// CHECK-DAG: %[[VF:.*]] = constant dense : vector<3xi1> -// CHECK-DAG: %[[C0:.*]] = constant 0 : index -// CHECK-DAG: %[[C1:.*]] = constant 1 : index -// CHECK-DAG: %[[C2:.*]] = constant 2 : index -// CHECK: %[[T0:.*]] = cmpi "slt", %[[C0]], %[[A]] : index -// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[VF]] [0] : i1 into vector<3xi1> -// CHECK: %[[T2:.*]] = cmpi "slt", %[[C1]], %[[A]] : index -// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : i1 into vector<3xi1> -// CHECK: %[[T4:.*]] = cmpi "slt", %[[C2]], %[[A]] : index -// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : i1 into vector<3xi1> -// CHECK: return %[[T5]] : vector<3xi1> +// CHECK-SAME: %[[A:.*]]: index +// CHECK: %[[CI:.*]] = constant dense<[0, 1, 2]> : vector<3xi64> +// CHECK: %[[T0:.*]] = index_cast %[[A]] : index to i64 +// CHECK: %[[T1:.*]] = splat %[[T0]] : 
vector<3xi64> +// CHECK: %[[T2:.*]] = cmpi "slt", %[[CI]], %[[T1]] : vector<3xi64> +// CHECK: return %[[T2]] : vector<3xi1> func @genbool_var_1d(%arg0: index) -> vector<3xi1> { %0 = vector.create_mask %arg0 : vector<3xi1> @@ -731,24 +725,21 @@ // CHECK-LABEL: func @genbool_var_2d // CHECK-SAME: %[[A:.*0]]: index // CHECK-SAME: %[[B:.*1]]: index -// CHECK-DAG: %[[Z1:.*]] = constant dense : vector<3xi1> -// CHECK-DAG: %[[Z2:.*]] = constant dense : vector<2x3xi1> -// CHECK-DAG: %[[C0:.*]] = constant 0 : index -// CHECK-DAG: %[[C1:.*]] = constant 1 : index -// CHECK-DAG: %[[C2:.*]] = constant 2 : index -// CHECK: %[[T0:.*]] = cmpi "slt", %[[C0]], %[[B]] : index -// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[Z1]] [0] : i1 into vector<3xi1> -// CHECK: %[[T2:.*]] = cmpi "slt", %[[C1]], %[[B]] : index -// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : i1 into vector<3xi1> -// CHECK: %[[T4:.*]] = cmpi "slt", %[[C2]], %[[B]] : index -// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : i1 into vector<3xi1> -// CHECK: %[[T6:.*]] = cmpi "slt", %[[C0]], %[[A]] : index -// CHECK: %[[T7:.*]] = select %[[T6]], %[[T5]], %[[Z1]] : vector<3xi1> -// CHECK: %[[T8:.*]] = vector.insert %7, %[[Z2]] [0] : vector<3xi1> into vector<2x3xi1> -// CHECK: %[[T9:.*]] = cmpi "slt", %[[C1]], %[[A]] : index -// CHECK: %[[T10:.*]] = select %[[T9]], %[[T5]], %[[Z1]] : vector<3xi1> -// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T8]] [1] : vector<3xi1> into vector<2x3xi1> -// CHECK: return %[[T11]] : vector<2x3xi1> +// CHECK-DAG: %[[CI:.*]] = constant dense<[0, 1, 2]> : vector<3xi64> +// CHECK-DAG: %[[CF:.*]] = constant dense : vector<3xi1> +// CHECK-DAG: %[[CF2:.*]] = constant dense : vector<2x3xi1> +// CHECK-DAG: %[[c0:.*]] = constant 0 : index +// CHECK-DAG: %[[c1:.*]] = constant 1 : index +// CHECK: %[[T0:.*]] = index_cast %[[B]] : index to i64 +// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xi64> +// CHECK: %[[T2:.*]] = cmpi "slt", %[[CI]], %[[T1]] : vector<3xi64> +// CHECK: %[[T3:.*]] = cmpi 
"slt", %[[c0]], %[[A]] : index +// CHECK: %[[T4:.*]] = select %[[T3]], %[[T2]], %[[CF]] : vector<3xi1> +// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[CF2]] [0] : vector<3xi1> into vector<2x3xi1> +// CHECK: %[[T6:.*]] = cmpi "slt", %[[c1]], %[[A]] : index +// CHECK: %[[T7:.*]] = select %[[T6]], %[[T2]], %[[CF]] : vector<3xi1> +// CHECK: %[[T8:.*]] = vector.insert %[[T7]], %[[T5]] [1] : vector<3xi1> into vector<2x3xi1> +// CHECK: return %[[T8]] : vector<2x3xi1> func @genbool_var_2d(%arg0: index, %arg1: index) -> vector<2x3xi1> { %0 = vector.create_mask %arg0, %arg1 : vector<2x3xi1>