diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -270,6 +270,7 @@
     }
   }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)";
+  let hasFolder = 1;
 }
 
 def Vector_ShuffleOp :
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -929,6 +929,22 @@
   return success();
 }
 
+/// Folds a `vector.broadcast` of a constant source into a single dense
+/// constant of the result type. A scalar int/index/float constant is used as
+/// the element value directly; a splat vector constant contributes its splat
+/// value. Returns a null result (no fold) for any other source.
+OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
+  // No attribute for the source operand: it is not a constant, nothing to fold.
+  if (!operands[0])
+    return {};
+  auto shapedType = getResult().getType().cast<ShapedType>();
+  if (operands[0].getType().isIntOrIndexOrFloat())
+    return DenseElementsAttr::get(shapedType, operands[0]);
+  if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
+    return DenseElementsAttr::get(shapedType, attr.getSplatValue());
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // ShuffleOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
--- a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
@@ -13,13 +13,8 @@
 }
 
 // CHECK-LABEL:func @matmul
-// CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32>
 // CHECK: store {{.*}}[] : memref<vector<8x16xf32>>
-//
-// CHECK: vector.broadcast {{.*}} : f32 to vector<16x12xf32>
 // CHECK: store {{.*}}[] : memref<vector<16x12xf32>>
-//
-// CHECK: vector.broadcast {{.*}} : f32 to vector<8x12xf32>
 // CHECK: store {{.*}}[] : memref<vector<8x12xf32>>
 //
 // CHECK: linalg.copy
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -385,3 +385,28 @@
   %2 = vector.bitcast %1 : vector<4xi16> to vector<2xi32>
   return %0, %2 : vector<4x8xf32>, vector<2xi32>
 }
+
+// -----
+
+// CHECK-LABEL: @broadcast_folding1
+// CHECK: %[[CST:.*]] = constant dense<42> : vector<4xi32>
+// CHECK-NOT: vector.broadcast
+// CHECK: return %[[CST]]
+func @broadcast_folding1() -> vector<4xi32> {
+  %0 = constant 42 : i32
+  %1 = vector.broadcast %0 : i32 to vector<4xi32>
+  return %1 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_folding2
+// CHECK: %[[CST:.*]] = constant dense<42> : vector<4x16xi32>
+// CHECK-NOT: vector.broadcast
+// CHECK: return %[[CST]]
+func @broadcast_folding2() -> vector<4x16xi32> {
+  %0 = constant 42 : i32
+  %1 = vector.broadcast %0 : i32 to vector<16xi32>
+  %2 = vector.broadcast %1 : vector<16xi32> to vector<4x16xi32>
+  return %2 : vector<4x16xi32>
+}