// Patch summary. These leading lines sit before the first "diff --git" header
// and are commentary only; they are not part of any hunk.
//
// mlir/lib/Dialect/Vector/IR/VectorOps.cpp
//   * Adds the rewrite pattern FoldTransposeSplat. Its matchAndRewrite takes a
//     TransposeOp and looks at the op that defines transposeOp.getVector():
//       - constantOp branch: casts constantOp.getValue() and returns failure()
//         when the cast yields null; otherwise reads dense.getSplatValue(),
//         re-wraps it with DenseElementsAttr::get(vecDstType, newAttr) when the
//         transpose result type casts successfully, replaces the transpose
//         with a new op built from newAttr, and returns success().
//       - splatOp branch: replaces the transpose with a new op built from
//         transposeOp.getResultType() and splatOp.getInput(), then returns
//         success().
//       - neither branch matched: returns failure().
//   * Re-wraps the results.add(...) call in
//     vector::TransposeOp::getCanonicalizationPatterns. The pattern list
//     passed to add is not visible in this copy; presumably FoldTransposeSplat
//     is appended to it -- TODO confirm.
//
// mlir/test/Dialect/Vector/canonicalize.mlir
//   * Adds @transpose_splat1 (vector.transpose of an arith.constant) and
//     @transpose_splat2 (vector.transpose of a vector.splat). The CHECK lines
//     expect an arith.constant / vector.splat of type vector<3x4xf32> that is
//     returned directly.
//
// mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
//   * Removes the CHECK lines that bound the masks to vector.transpose results
//     and adds CHECK-DAG lines binding MASK0, MASK1 and MASK to arith.constant
//     values of the permuted mask types.
//   * The third transfer_read is now checked against MASK0 instead of MASK2.
//
// NOTE(review): in this copy the hunk line breaks appear collapsed and several
// angle-bracket arguments look stripped (e.g. "OpRewritePattern {",
// "getDefiningOp()", "dyn_cast()", "replaceOpWithNewOp(", ".add(",
// "SmallVectorImpl &", "memref,", "tensor", "dense :"). Presumably this happened
// when the patch text was extracted; restore from the upstream revision before
// applying -- TODO confirm.
// NOTE(review): dense.getSplatValue() is reached after only a null check on
// the cast result. If the (not visible) cast target also admits non-splat
// attributes, a splat check may be needed first -- confirm against the cast type.
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -4388,11 +4388,42 @@ } }; +// Folds transpose(splat x : src_type) : res_type into splat x : res_type. +class FoldTransposeSplat final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TransposeOp transposeOp, + PatternRewriter &rewriter) const override { + if (auto constantOp = + transposeOp.getVector().getDefiningOp()) { + auto dense = constantOp.getValue().dyn_cast(); + if (!dense) + return failure(); + Attribute newAttr = dense.getSplatValue(); + if (auto vecDstType = transposeOp.getType().dyn_cast()) + newAttr = DenseElementsAttr::get(vecDstType, newAttr); + rewriter.replaceOpWithNewOp(transposeOp, newAttr); + return success(); + } + if (auto splatOp = + transposeOp.getVector().getDefiningOp()) { + rewriter.replaceOpWithNewOp( + transposeOp, transposeOp.getResultType(), splatOp.getInput()); + return success(); + } + + return failure(); + } +}; + } // namespace void vector::TransposeOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results + .add( + context); } void vector::TransposeOp::getTransp(SmallVectorImpl &results) { diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1463,6 +1463,29 @@ // ----- +// CHECK-LABEL: func @transpose_splat1() -> vector<3x4xf32> { +// CHECK: %[[VAL_0:.*]] = arith.constant dense<1.000000e+00> : vector<3x4xf32> +// CHECK: return %[[VAL_0]] : vector<3x4xf32> +// CHECK: } +func @transpose_splat1() -> vector<3x4xf32> { + %splat = arith.constant dense<1.0> : vector<4x3xf32> + %0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32> + return %0 : 
vector<3x4xf32> +} + +// CHECK-LABEL: func @transpose_splat2( +// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> { +// CHECK: %[[VAL_1:.*]] = vector.splat %[[VAL_0]] : vector<3x4xf32> +// CHECK: return %[[VAL_1]] : vector<3x4xf32> +// CHECK: } +func @transpose_splat2(%arg : f32) -> vector<3x4xf32> { + %splat = vector.splat %arg : vector<4x3xf32> + %0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32> + return %0 : vector<3x4xf32> +} + +// ----- + // CHECK-LABEL: func @insert_element_fold // CHECK: %[[V:.+]] = arith.constant dense<[0, 1, 7, 3]> : vector<4xi32> // CHECK: return %[[V]] diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -276,6 +276,8 @@ func @transfer_read_permutations(%arg0 : memref, %arg1 : memref) -> (vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<8xf32>) { +// CHECK-DAG: %[[MASK0:.*]] = arith.constant dense : vector<14x7xi1> +// CHECK-DAG: %[[MASK1:.*]] = arith.constant dense : vector<16x14xi1> // CHECK-DAG: %[[CF0:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index %cst = arith.constant 0.000000e+00 : f32 @@ -284,20 +286,17 @@ %mask0 = vector.splat %m : vector<7x14xi1> %0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask0 {in_bounds = [true, false, true, true], permutation_map = #map0} : memref, vector<7x14x8x16xf32> -// CHECK: %[[MASK0:.*]] = vector.transpose {{.*}} : vector<7x14xi1> to vector<14x7xi1> // CHECK: vector.transfer_read {{.*}} %[[MASK0]] {in_bounds = [false, true, true, true], permutation_map = #[[$MAP0]]} : memref, vector<14x7x8x16xf32> // CHECK: vector.transpose %{{.*}}, [1, 0, 2, 3] : vector<14x7x8x16xf32> to 
vector<7x14x8x16xf32> %mask1 = vector.splat %m : vector<14x16xi1> %1 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask1 {permutation_map = #map1} : memref, vector<7x14x8x16xf32> -// CHECK: %[[MASK1:.*]] = vector.transpose {{.*}} : vector<14x16xi1> to vector<16x14xi1> // CHECK: vector.transfer_read {{.*}} %[[MASK1]] {permutation_map = #[[$MAP0]]} : memref, vector<16x14x7x8xf32> // CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32> %mask2 = vector.splat %m : vector<7x14xi1> %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask2 {in_bounds = [true, false, true, true], permutation_map = #map2} : memref, vector<7x14x8x16xf32> -// CHECK: %[[MASK2:.*]] = vector.transpose {{.*}} : vector<7x14xi1> to vector<14x7xi1> -// CHECK: vector.transfer_read {{.*}} %[[MASK2]] {in_bounds = [false, true, true], permutation_map = #[[$MAP1]]} : memref, vector<14x16x7xf32> +// CHECK: vector.transfer_read {{.*}} %[[MASK0]] {in_bounds = [false, true, true], permutation_map = #[[$MAP1]]} : memref, vector<14x16x7xf32> // CHECK: vector.broadcast %{{.*}} : vector<14x16x7xf32> to vector<8x14x16x7xf32> // CHECK: vector.transpose %{{.*}}, [3, 1, 0, 2] : vector<8x14x16x7xf32> to vector<7x14x8x16xf32> @@ -332,15 +331,15 @@ func @transfer_write_permutations( %arg0 : memref, %arg1 : tensor, %v1 : vector<7x14x8x16xf32>, %v2 : vector<8x16xf32>) -> tensor { + // CHECK-DAG: %[[MASK:.*]] = arith.constant dense : vector<8x14x16x7xi1> // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index %c0 = arith.constant 0 : index %m = arith.constant 1 : i1 %mask0 = vector.splat %m : vector<7x14x8x16xi1> %0 = vector.transfer_write %v1, %arg1[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, tensor - // CHECK: %[[NEW_MASK0:.*]] = vector.transpose %{{.*}} [2, 1, 3, 0] : vector<7x14x8x16xi1> to vector<8x14x16x7xi1> // CHECK: %[[NEW_VEC0:.*]] = 
vector.transpose %{{.*}} [2, 1, 3, 0] : vector<7x14x8x16xf32> to vector<8x14x16x7xf32> - // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[ARG1]][%c0, %c0, %c0, %c0], %[[NEW_MASK0]] {in_bounds = [false, false, true, true]} : vector<8x14x16x7xf32>, tensor + // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[ARG1]][%c0, %c0, %c0, %c0], %[[MASK]] {in_bounds = [false, false, true, true]} : vector<8x14x16x7xf32>, tensor vector.transfer_write %v2, %arg0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>} : vector<8x16xf32>, memref // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %{{.*}} [1, 0] : vector<8x16xf32> to vector<16x8xf32>