diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -270,6 +270,7 @@
     }
   }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)";
+  let hasFolder = 1;
 }
 
 def Vector_ShuffleOp :
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -929,6 +929,22 @@
   return success();
 }
 
+/// Folds a broadcast of a constant into a splat constant of the result vector
+/// type. Scalar constants and splat vector constants both fold; non-splat
+/// vector constants are left unchanged.
+OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
+  // The folder only provides a non-null `operands[0]` when the source is a
+  // known constant, which covers every constant-like op, not just ConstantOp.
+  if (!operands[0])
+    return {};
+  auto vectorType = getResult().getType().cast<VectorType>();
+  if (source().getType().isIntOrIndexOrFloat())
+    return DenseElementsAttr::get(vectorType, operands[0]);
+  if (auto splat = operands[0].dyn_cast<SplatElementsAttr>())
+    return DenseElementsAttr::get(vectorType, splat.getSplatValue());
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // ShuffleOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
--- a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
@@ -13,13 +13,8 @@
 }
 
 // CHECK-LABEL:func @matmul
-// CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32>
 // CHECK: store {{.*}}[] : memref<vector<8x16xf32>>
-//
-// CHECK: vector.broadcast {{.*}} : f32 to vector<16x12xf32>
 // CHECK: store {{.*}}[] : memref<vector<16x12xf32>>
-//
-// CHECK: vector.broadcast {{.*}} : f32 to vector<8x12xf32>
 // CHECK: store {{.*}}[] : memref<vector<8x12xf32>>
 //
 // CHECK: linalg.copy
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -385,3 +385,29 @@
   %2 = vector.bitcast %1 : vector<4xi16> to vector<2xi32>
   return %0, %2 : vector<4x8xf32>, vector<2xi32>
 }
+
+// -----
+
+// CHECK-LABEL: broadcast_folding1
+// CHECK: %[[CST:.*]] = constant dense<{{.*}}> : vector<4xf32>
+// CHECK-NOT: vector.broadcast
+// CHECK: return %[[CST]]
+func @broadcast_folding1() -> vector<4xf32> {
+  %0 = constant 0.0 : f32
+  %1 = vector.broadcast %0 : f32 to vector<4xf32>
+  return %1 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_folding2
+// CHECK: %[[CST:.*]] = constant dense<2> : vector<2x4x16xi32>
+// CHECK-NOT: vector.broadcast
+// CHECK: return %[[CST]]
+func @broadcast_folding2() -> vector<2x4x16xi32> {
+  %0 = constant 2 : i32
+  %1 = vector.broadcast %0 : i32 to vector<16xi32>
+  %2 = vector.broadcast %1 : vector<16xi32> to vector<4x16xi32>
+  %3 = vector.broadcast %2 : vector<4x16xi32> to vector<2x4x16xi32>
+  return %3 : vector<2x4x16xi32>
+}