diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -120,6 +120,7 @@
   IntegerAttr getUI32IntegerAttr(uint32_t value);
 
   /// Vector-typed DenseIntElementsAttr getters. `values` must not be empty.
+  DenseIntElementsAttr getBoolVectorAttr(ArrayRef<bool> values);
   DenseIntElementsAttr getI32VectorAttr(ArrayRef<int32_t> values);
   DenseIntElementsAttr getI64VectorAttr(ArrayRef<int64_t> values);
 
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1311,7 +1311,8 @@
 /// %4 = vector.insert %l, %z[0]
 /// ..
 /// %x = vector.insert %l, %..[a-1]
-/// which will be folded at LLVM IR level.
+/// until a one-dimensional vector is reached. All these operations
+/// will be folded at LLVM IR level.
 class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
 public:
   using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;
@@ -1325,20 +1326,24 @@
     int64_t rank = dimSizes.size();
     int64_t trueDim = dimSizes[0].cast<IntegerAttr>().getInt();
 
-    Value trueVal;
     if (rank == 1) {
-      trueVal = rewriter.create<ConstantOp>(
-          loc, eltType, rewriter.getIntegerAttr(eltType, 1));
-    } else {
-      VectorType lowType =
-          VectorType::get(dstType.getShape().drop_front(), eltType);
-      SmallVector<int64_t, 4> newDimSizes;
-      for (int64_t r = 1; r < rank; r++)
-        newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
-      trueVal = rewriter.create<vector::ConstantMaskOp>(
-          loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
+      // Express the 1-D case directly as one explicit constant vector
+      // [T,..,T,F,..,F]; entries at or beyond trueDim stay value-initialized (false).
+      SmallVector<bool, 4> values(dstType.getDimSize(0));
+      for (int64_t d = 0; d < trueDim; d++)
+        values[d] = true;
+      rewriter.replaceOpWithNewOp<ConstantOp>(
+          op, dstType, rewriter.getBoolVectorAttr(values));
+      return success();
     }
 
+    VectorType lowType =
+        VectorType::get(dstType.getShape().drop_front(), eltType);
+    SmallVector<int64_t, 4> newDimSizes;
+    for (int64_t r = 1; r < rank; r++)
+      newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
+    Value trueVal = rewriter.create<vector::ConstantMaskOp>(
+        loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
     Value result = rewriter.create<ConstantOp>(loc, dstType,
                                                rewriter.getZeroAttr(dstType));
     for (int64_t d = 0; d < trueDim; d++) {
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -104,6 +104,12 @@
   return IntegerAttr::get(getIntegerType(64), APInt(64, value));
 }
 
+DenseIntElementsAttr Builder::getBoolVectorAttr(ArrayRef<bool> values) {
+  return DenseIntElementsAttr::get(
+      VectorType::get(static_cast<int64_t>(values.size()), getI1Type()),
+      values);
+}
+
 DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef<int32_t> values) {
   return DenseIntElementsAttr::get(
       VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(32)),
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -944,17 +944,26 @@
   return %0 : vector<8xi1>
 }
 // CHECK-LABEL: func @genbool_1d
-// CHECK: %[[T0:.*]] = llvm.mlir.constant(true) : !llvm.i1
-// CHECK: %[[T1:.*]] = llvm.mlir.constant(dense<false> : vector<8xi1>) : !llvm<"<8 x i1>">
-// CHECK: %[[T2:.*]] = llvm.mlir.constant(0 : i64) : !llvm.i64
-// CHECK: %[[T3:.*]] = llvm.insertelement %[[T0]], %[[T1]][%[[T2]] : !llvm.i64] : !llvm<"<8 x i1>">
-// CHECK: %[[T4:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64
-// CHECK: %[[T5:.*]] = llvm.insertelement %[[T0]], %[[T3]][%[[T4]] : !llvm.i64] : !llvm<"<8 x i1>">
-// CHECK: %[[T6:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
-// CHECK: %[[T7:.*]] = llvm.insertelement %[[T0]], %[[T5]][%[[T6]] : !llvm.i64] : !llvm<"<8 x i1>">
-// CHECK: %[[T8:.*]] = llvm.mlir.constant(3 : i64) : !llvm.i64
-// CHECK: %[[T9:.*]] = llvm.insertelement %[[T0]], %[[T7]][%[[T8]] : !llvm.i64] : !llvm<"<8 x i1>">
-// CHECK: llvm.return %9 : !llvm<"<8 x i1>">
+// CHECK: %[[C1:.*]] = llvm.mlir.constant(dense<[true, true, true, true, false, false, false, false]> : vector<8xi1>) : !llvm<"<8 x i1>">
+// CHECK: llvm.return %[[C1]] : !llvm<"<8 x i1>">
+
+func @genbool_2d() -> vector<4x4xi1> {
+  %v = vector.constant_mask [2, 2] : vector<4x4xi1>
+  return %v: vector<4x4xi1>
+}
+
+// CHECK-LABEL: func @genbool_2d
+// CHECK: %[[C1:.*]] = llvm.mlir.constant(dense<[true, true, false, false]> : vector<4xi1>) : !llvm<"<4 x i1>">
+// CHECK: %[[C2:.*]] = llvm.mlir.constant(dense<false> : vector<4x4xi1>) : !llvm<"[4 x <4 x i1>]">
+// CHECK: %[[T0:.*]] = llvm.insertvalue %[[C1]], %[[C2]][0] : !llvm<"[4 x <4 x i1>]">
+// CHECK: %[[T1:.*]] = llvm.insertvalue %[[C1]], %[[T0]][1] : !llvm<"[4 x <4 x i1>]">
+// CHECK: llvm.return %[[T1]] : !llvm<"[4 x <4 x i1>]">
+
+func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
+  %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
+     : vector<16xf32> -> vector<16xf32>
+  return %0 : vector<16xf32>
+}
 
 // CHECK-LABEL: func @flat_transpose
 // CHECK-SAME: %[[A:.*]]: !llvm<"<16 x float>">
@@ -962,8 +971,3 @@
 // CHECK-SAME: {columns = 4 : i32, rows = 4 : i32} :
 // CHECK-SAME: !llvm<"<16 x float>"> into !llvm<"<16 x float>">
 // CHECK: llvm.return %[[T]] : !llvm<"<16 x float>">
-func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
-  %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
-     : vector<16xf32> -> vector<16xf32>
-  return %0 : vector<16xf32>
-}
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -676,13 +676,8 @@
 }
 
 // CHECK-LABEL: func @genbool_1d
-// CHECK: %[[TT:.*]] = constant true
-// CHECK: %[[C1:.*]] = constant dense<false> : vector<8xi1>
-// CHECK: %[[T0:.*]] = vector.insert %[[TT]], %[[C1]] [0] : i1 into vector<8xi1>
-// CHECK: %[[T1:.*]] = vector.insert %[[TT]], %[[T0]] [1] : i1 into vector<8xi1>
-// CHECK: %[[T2:.*]] = vector.insert %[[TT]], %[[T1]] [2] : i1 into vector<8xi1>
-// CHECK: %[[T3:.*]] = vector.insert %[[TT]], %[[T2]] [3] : i1 into vector<8xi1>
-// CHECK: return %[[T3]] : vector<8xi1>
+// CHECK: %[[T0:.*]] = constant dense<[true, true, true, true, false, false, false, false]> : vector<8xi1>
+// CHECK: return %[[T0]] : vector<8xi1>
 
 func @genbool_1d() -> vector<8xi1> {
   %0 = vector.constant_mask [4] : vector<8xi1>
@@ -690,14 +685,11 @@
 }
 
 // CHECK-LABEL: func @genbool_2d
-// CHECK: %[[TT:.*]] = constant true
-// CHECK: %[[C1:.*]] = constant dense<false> : vector<4xi1>
+// CHECK: %[[C1:.*]] = constant dense<[true, true, false, false]> : vector<4xi1>
 // CHECK: %[[C2:.*]] = constant dense<false> : vector<4x4xi1>
-// CHECK: %[[T0:.*]] = vector.insert %[[TT]], %[[C1]] [0] : i1 into vector<4xi1>
-// CHECK: %[[T1:.*]] = vector.insert %[[TT]], %[[T0]] [1] : i1 into vector<4xi1>
-// CHECK: %[[T2:.*]] = vector.insert %[[T1]], %[[C2]] [0] : vector<4xi1> into vector<4x4xi1>
-// CHECK: %[[T3:.*]] = vector.insert %[[T1]], %[[T2]] [1] : vector<4xi1> into vector<4x4xi1>
-// CHECK: return %[[T3]] : vector<4x4xi1>
+// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<4x4xi1>
+// CHECK: %[[T1:.*]] = vector.insert %[[C1]], %[[T0]] [1] : vector<4xi1> into vector<4x4xi1>
+// CHECK: return %[[T1]] : vector<4x4xi1>
 
 func @genbool_2d() -> vector<4x4xi1> {
   %v = vector.constant_mask [2, 2] : vector<4x4xi1>
@@ -705,16 +697,12 @@
 }
 
 // CHECK-LABEL: func @genbool_3d
-// CHECK: %[[TT:.*]] = constant true
-// CHECK: %[[C1:.*]] = constant dense<false> : vector<4xi1>
+// CHECK: %[[C1:.*]] = constant dense<[true, true, true, false]> : vector<4xi1>
 // CHECK: %[[C2:.*]] = constant dense<false> : vector<3x4xi1>
 // CHECK: %[[C3:.*]] = constant dense<false> : vector<2x3x4xi1>
-// CHECK: %[[T0:.*]] = vector.insert %[[TT]], %[[C1]] [0] : i1 into vector<4xi1>
-// CHECK: %[[T1:.*]] = vector.insert %[[TT]], %[[T0]] [1] : i1 into vector<4xi1>
-// CHECK: %[[T2:.*]] = vector.insert %[[TT]], %[[T1]] [2] : i1 into vector<4xi1>
-// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<4xi1> into vector<3x4xi1>
-// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C3]] [0] : vector<3x4xi1> into vector<2x3x4xi1>
-// CHECK: return %[[T4]] : vector<2x3x4xi1>
+// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<3x4xi1>
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C3]] [0] : vector<3x4xi1> into vector<2x3x4xi1>
+// CHECK: return %[[T1]] : vector<2x3x4xi1>
 
 func @genbool_3d() -> vector<2x3x4xi1> {
   %v = vector.constant_mask [1, 1, 3] : vector<2x3x4xi1>