diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -270,6 +270,7 @@
     }
   }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)";
+  let hasFolder = 1;
 }
 
 def Vector_ShuffleOp :
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -929,6 +929,29 @@
   return success();
 }
 
+/// Folds a broadcast of a constant into a constant vector.
+///
+/// `operands[0]` is the constant attribute of `source()` when that value is
+/// produced by a constant-like op, and null otherwise. Two cases fold:
+///   * a scalar integer/index/float constant becomes a splat of the result
+///     vector type;
+///   * a splat vector constant has its splat value re-splatted into the
+///     result vector type.
+/// Non-splat vector constants are left alone.
+OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
+  // Use the folder-provided operand attribute rather than looking for a
+  // specific defining op, so any constant-like producer is handled.
+  Attribute srcCst = operands[0];
+  if (!srcCst)
+    return {};
+  VectorType vectorType = getVectorType();
+  if (srcCst.getType().isIntOrIndexOrFloat())
+    return DenseElementsAttr::get(vectorType, srcCst);
+  if (auto splat = srcCst.dyn_cast<SplatElementsAttr>())
+    return DenseElementsAttr::get(vectorType, splat.getSplatValue());
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // ShuffleOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
--- a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
@@ -13,13 +13,8 @@
 }
 
 // CHECK-LABEL:func @matmul
-// CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32>
 // CHECK: store {{.*}}[] : memref<vector<8x16xf32>>
-//
-// CHECK: vector.broadcast {{.*}} : f32 to vector<16x12xf32>
 // CHECK: store {{.*}}[] : memref<vector<16x12xf32>>
-//
-// CHECK: vector.broadcast {{.*}} : f32 to vector<8x12xf32>
 // CHECK: store {{.*}}[] : memref<vector<8x12xf32>>
 //
 // CHECK: linalg.copy
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -385,3 +385,28 @@
   %2 = vector.bitcast %1 : vector<4xi16> to vector<2xi32>
   return %0, %2 : vector<4x8xf32>, vector<2xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func @broadcast_folding1
+//       CHECK:   %[[CST:.*]] = constant dense<4.200000e+01> : vector<4xf32>
+//   CHECK-NOT:   vector.broadcast
+//       CHECK:   return %[[CST]]
+func @broadcast_folding1() -> vector<4xf32> {
+  %0 = constant 42.0 : f32
+  %1 = vector.broadcast %0 : f32 to vector<4xf32>
+  return %1 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @broadcast_folding2
+//       CHECK:   %[[CST:.*]] = constant dense<4.200000e+01> : vector<4x16xf32>
+//   CHECK-NOT:   vector.broadcast
+//       CHECK:   return %[[CST]]
+func @broadcast_folding2() -> vector<4x16xf32> {
+  %0 = constant 42.0 : f32
+  %1 = vector.broadcast %0 : f32 to vector<16xf32>
+  %2 = vector.broadcast %1 : vector<16xf32> to vector<4x16xf32>
+  return %2 : vector<4x16xf32>
+}