diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -4278,10 +4278,14 @@ result.addAttribute(getTranspAttrStrName(), builder.getI64ArrayAttr(transp)); } -// Eliminates transpose operations, which produce values identical to their -// input values. This happens when the dimensions of the input vector remain in -// their original order after the transpose operation. OpFoldResult vector::TransposeOp::fold(ArrayRef operands) { + // Eliminate splat constant transpose ops. + if (auto attr = operands.front().dyn_cast_or_null()) + if (attr.isSplat()) + return attr.reshape(getResultType()); + + // Eliminate identity transpose ops. This happens when the dimensions of the + // input vector remain in their original order after the transpose operation. SmallVector transp; getTransp(transp); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1463,6 +1463,17 @@ // ----- +// CHECK-LABEL: func @transpose_splat_constant +// CHECK: %[[CST:.+]] = arith.constant dense<5.000000e+00> : vector<8x4xf32> +// CHECK: return %[[CST]] +func @transpose_splat_constant() -> vector<8x4xf32> { + %cst = arith.constant dense<5.0> : vector<4x8xf32> + %0 = vector.transpose %cst, [1, 0] : vector<4x8xf32> to vector<8x4xf32> + return %0 : vector<8x4xf32> +} + +// ----- + // CHECK-LABEL: func @insert_element_fold // CHECK: %[[V:.+]] = arith.constant dense<[0, 1, 7, 3]> : vector<4xi32> // CHECK: return %[[V]] diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -273,14 +273,13 @@ // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)> // CHECK-LABEL: func @transfer_read_permutations -func @transfer_read_permutations(%arg0 : memref, %arg1 : memref) +func @transfer_read_permutations(%arg0 : memref, %arg1 : memref, %m: i1) -> (vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<8xf32>) { // CHECK-DAG: %[[CF0:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index - %m = arith.constant 1 : i1 %mask0 = vector.splat %m : vector<7x14xi1> %0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask0 {in_bounds = [true, false, true, true], permutation_map = #map0} : memref, vector<7x14x8x16xf32> @@ -331,10 +330,9 @@ // CHECK-SAME: %[[ARG1:.*]]: tensor func @transfer_write_permutations( %arg0 : memref, %arg1 : tensor, - %v1 : vector<7x14x8x16xf32>, %v2 : vector<8x16xf32>) -> tensor { + %v1 : vector<7x14x8x16xf32>, %v2 : vector<8x16xf32>, %m: i1) -> tensor { // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index %c0 = arith.constant 0 : index - %m = arith.constant 1 : i1 %mask0 = vector.splat %m : vector<7x14x8x16xi1> %0 = vector.transfer_write %v1, %arg1[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, tensor