diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4088,7 +4088,7 @@
 }
 
 OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
-  // Nop shape cast.
+  // No-op shape cast.
   if (getSource().getType() == getResult().getType())
     return getSource();
 
@@ -4113,6 +4113,14 @@
     setOperand(otherOp.getSource());
     return getResult();
   }
+
+  // Cancelling broadcast and shape cast ops: shape_cast(broadcast(x)) folds
+  // to x when x already has the shape cast's result type.
+  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
+    if (bcastOp.getSourceType() == getType())
+      return bcastOp.getSource();
+  }
+
   return {};
 }
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -635,6 +635,42 @@
 
 // -----
 
+// shape_cast(broadcast(x)) folds to x when x already has the result type.
+// CHECK-LABEL: func @fold_broadcast_shapecast
+//  CHECK-SAME: (%[[V:.+]]: vector<4xf32>)
+//       CHECK: return %[[V]]
+func.func @fold_broadcast_shapecast(%arg0: vector<4xf32>) -> vector<4xf32> {
+  %0 = vector.broadcast %arg0 : vector<4xf32> to vector<1x1x4xf32>
+  %1 = vector.shape_cast %0 : vector<1x1x4xf32> to vector<4xf32>
+  return %1 : vector<4xf32>
+}
+
+// -----
+
+// No fold: the scalar broadcast source type (f32) is not the result type.
+// CHECK-LABEL: func @dont_fold_broadcast_shapecast_scalar
+//       CHECK: vector.broadcast
+//       CHECK: vector.shape_cast
+func.func @dont_fold_broadcast_shapecast_scalar(%arg0: f32) -> vector<1xf32> {
+  %0 = vector.broadcast %arg0 : f32 to vector<1x1x1xf32>
+  %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1xf32>
+  return %1 : vector<1xf32>
+}
+
+// -----
+
+// No fold: the broadcast source type differs from the shape_cast result type.
+// CHECK-LABEL: func @dont_fold_broadcast_shapecast_diff_shape
+//       CHECK: vector.broadcast
+//       CHECK: vector.shape_cast
+func.func @dont_fold_broadcast_shapecast_diff_shape(%arg0: vector<4xf32>) -> vector<8xf32> {
+  %0 = vector.broadcast %arg0 : vector<4xf32> to vector<1x2x4xf32>
+  %1 = vector.shape_cast %0 : vector<1x2x4xf32> to vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfers
 func.func @fold_vector_transfers(%A: memref<?x8xf32>) -> (vector<4x8xf32>, vector<4x9xf32>) {
   %c0 = arith.constant 0 : index