diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -270,6 +270,7 @@
     }
   }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)";
+  let hasFolder = 1;
 }
 
 def Vector_ShuffleOp :
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -929,6 +929,31 @@
   return success();
 }
 
+/// Folds `vector.broadcast` in two cases:
+///  * the source is a constant scalar (integer, index or float) or a splat
+///    elements constant: the result folds to a splat attribute of the result
+///    vector type;
+///  * the source is produced by another `vector.broadcast`: this op is updated
+///    in place to broadcast the inner op's source directly.
+OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
+  auto vectorType = getResult().getType().cast<VectorType>();
+  // `operands[0]` holds the constant value of the source when it is known, so
+  // any constant-like producer folds, not only a specific constant op.
+  if (Attribute sourceAttr = operands[0]) {
+    if (sourceAttr.getType().isIntOrIndexOrFloat())
+      return DenseElementsAttr::get(vectorType, sourceAttr);
+    if (auto splat = sourceAttr.dyn_cast<SplatElementsAttr>())
+      return DenseElementsAttr::get(vectorType, splat.getSplatValue());
+  }
+  // broadcast(broadcast(x)) -> broadcast(x). Returning this op's own result
+  // signals that the op was updated in place rather than replaced.
+  if (auto prevOp = source().getDefiningOp<BroadcastOp>()) {
+    setOperand(prevOp.source());
+    return getResult();
+  }
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // ShuffleOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -385,3 +385,37 @@
   %2 = vector.bitcast %1 : vector<4xi16> to vector<2xi32>
   return %0, %2 : vector<4x8xf32>, vector<2xi32>
 }
+
+// -----
+
+// CHECK-LABEL: @broadcast_folding1
+// CHECK: %[[CST:.*]] = constant dense<{{.*}}> : vector<4xf32>
+// CHECK-NOT: vector.broadcast
+// CHECK: return %[[CST]]
+func @broadcast_folding1() -> vector<4xf32> {
+  %0 = constant 0.0 : f32
+  %1 = vector.broadcast %0 : f32 to vector<4xf32>
+  return %1 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_folding2
+// CHECK-SAME: %[[ARG0:.*]]: f32
+// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG0]] : f32 to vector<4x16xf32>
+// CHECK: return %[[RES]]
+func @broadcast_folding2(%arg0: f32) -> vector<4x16xf32> {
+  %0 = vector.broadcast %arg0 : f32 to vector<16xf32>
+  %1 = vector.broadcast %0 : vector<16xf32> to vector<4x16xf32>
+  return %1 : vector<4x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_folding3
+// CHECK: %[[CST:.*]] = constant dense<{{.*}}> : vector<4x16xf32>
+// CHECK-NOT: vector.broadcast
+// CHECK: return %[[CST]]
+func @broadcast_folding3() -> vector<4x16xf32> {
+  %0 = constant dense<2.0> : vector<16xf32>
+  %1 = vector.broadcast %0 : vector<16xf32> to vector<4x16xf32>
+  return %1 : vector<4x16xf32>
+}