[mlir][arith] Fold addi of two same-typed vector.broadcasts

Canonicalize

  %bx = vector.broadcast %x : T to V
  %by = vector.broadcast %y : T to V
  %r  = arith.addi %bx, %by : V

into

  %s = arith.addi %x, %y : T
  %r = vector.broadcast %s : T to V

The rewrite is only valid when %x and %y have the same type T: arith.addi
requires identical operand types, yet two differently-typed sources (for
example `index` and `vector<4xindex>`) can both broadcast to the same
`vector<2x4xindex>`. The pattern therefore carries an explicit type-equality
constraint, and a negative test covers the mismatched-source case.

The pattern is registered with its own `patterns.add<...>` call so the hunk
does not need to rewrite the existing AddIOp pattern list.

NOTE(review): this makes the Arith dialect library depend on the Vector
dialect (new VectorOps.td / VectorOps.h includes), while Vector presumably
already depends on Arith; please confirm this does not introduce a library
cycle. The extra MLIRArithTransforms links in the unit-test CMake files look
like a workaround for resulting link errors rather than a fix. Consider
hosting the pattern on the Vector side instead (upstream has vector
"sink broadcast" patterns, e.g. vector::populateSinkVectorBroadcastPatterns,
if present in this tree).

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -11,6 +11,7 @@
 
 include "mlir/IR/PatternBase.td"
 include "mlir/Dialect/Arith/IR/ArithOps.td"
+include "mlir/Dialect/Vector/IR/VectorOps.td"
 
 // Create zero attribute of type matching the argument's type.
 def GetZeroAttr : NativeCodeCall<"$_builder.getZeroAttr($0.getType())">;
@@ -35,6 +36,22 @@
           (ConstantLikeMatcher APIntAttr:$c1)),
         (Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)))>;
 
+// Holds when the two values have exactly the same type.
+def BroadcastSourcesHaveSameType :
+    Constraint<CPred<"$0.getType() == $1.getType()">,
+               "broadcast sources have the same type">;
+
+// addi(vector.broadcast(x), vector.broadcast(y))
+//   -> vector.broadcast(addi(x, y))
+// Only applies when x and y have the same type: sources such as `index` and
+// `vector<4xindex>` may both broadcast to `vector<2x4xindex>`, but arith.addi
+// requires identical operand types, so such pairs cannot be combined.
+def AddIVectorBroadcast :
+    Pat<(Arith_AddIOp
+            (Vector_BroadcastOp $x), (Vector_BroadcastOp $y)),
+        (Vector_BroadcastOp (Arith_AddIOp $x, $y)),
+        [(BroadcastSourcesHaveSameType $x, $y)]>;
+
 // addi(subi(x, c0), c1) -> addi(x, c1 - c0)
 def AddISubConstantRHS :
     Pat<(Arith_AddIOp:$res
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/CommonFolders.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -263,3 +264,4 @@
+  patterns.add<AddIVectorBroadcast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2417,3 +2417,28 @@
   %r = arith.shrsi %x, %c0 : i64
   return %r : i64
 }
+
+// CHECK-LABEL: addWithBroadcast
+// CHECK-SAME: %[[VAL_0:.*]]: index,
+// CHECK-SAME: %[[VAL_1:.*]]: index) -> vector<1x4xindex> {
+// CHECK: %[[VAL_2:.*]] = arith.addi %[[VAL_0]], %[[VAL_1]] : index
+// CHECK: %[[VAL_3:.*]] = vector.broadcast %[[VAL_2]] : index to vector<1x4xindex>
+// CHECK: return %[[VAL_3]] : vector<1x4xindex>
+func.func @addWithBroadcast(%arg1: index, %arg2: index) -> vector<1x4xindex> {
+  %0 = vector.broadcast %arg1 : index to vector<1x4xindex>
+  %1 = vector.broadcast %arg2 : index to vector<1x4xindex>
+  %2 = arith.addi %0, %1 : vector<1x4xindex>
+  return %2 : vector<1x4xindex>
+}
+
+// Broadcasts of differently-typed sources must not be combined: the
+// combined arith.addi would have mismatched operand types.
+// CHECK-LABEL: @addWithBroadcastMismatchedSources
+// CHECK: %[[ADD:.*]] = arith.addi %{{.*}}, %{{.*}} : vector<2x4xindex>
+// CHECK: return %[[ADD]] : vector<2x4xindex>
+func.func @addWithBroadcastMismatchedSources(%arg1: index, %arg2: vector<4xindex>) -> vector<2x4xindex> {
+  %0 = vector.broadcast %arg1 : index to vector<2x4xindex>
+  %1 = vector.broadcast %arg2 : vector<4xindex> to vector<2x4xindex>
+  %2 = arith.addi %0, %1 : vector<2x4xindex>
+  return %2 : vector<2x4xindex>
+}
diff --git a/mlir/unittests/Conversion/PDLToPDLInterp/CMakeLists.txt b/mlir/unittests/Conversion/PDLToPDLInterp/CMakeLists.txt
--- a/mlir/unittests/Conversion/PDLToPDLInterp/CMakeLists.txt
+++ b/mlir/unittests/Conversion/PDLToPDLInterp/CMakeLists.txt
@@ -3,6 +3,6 @@
 )
 target_link_libraries(MLIRPDLToPDLInterpTests
   PRIVATE
-  MLIRArithDialect
+  MLIRArithTransforms
   MLIRPDLToPDLInterp
 )
diff --git a/mlir/unittests/Dialect/Transform/CMakeLists.txt b/mlir/unittests/Dialect/Transform/CMakeLists.txt
--- a/mlir/unittests/Dialect/Transform/CMakeLists.txt
+++ b/mlir/unittests/Dialect/Transform/CMakeLists.txt
@@ -3,6 +3,7 @@
 )
 target_link_libraries(MLIRTransformDialectTests
   PRIVATE
+  MLIRArithTransforms
   MLIRFuncDialect
   MLIRTransformDialect
 )
diff --git a/mlir/unittests/Interfaces/CMakeLists.txt b/mlir/unittests/Interfaces/CMakeLists.txt
--- a/mlir/unittests/Interfaces/CMakeLists.txt
+++ b/mlir/unittests/Interfaces/CMakeLists.txt
@@ -7,6 +7,7 @@
 
 target_link_libraries(MLIRInterfacesTests
   PRIVATE
+  MLIRArithTransforms
   MLIRControlFlowInterfaces
   MLIRDataLayoutInterfaces
   MLIRDLTIDialect
diff --git a/mlir/unittests/Pass/CMakeLists.txt b/mlir/unittests/Pass/CMakeLists.txt
--- a/mlir/unittests/Pass/CMakeLists.txt
+++ b/mlir/unittests/Pass/CMakeLists.txt
@@ -5,5 +5,6 @@
 )
 target_link_libraries(MLIRPassTests
   PRIVATE
+  MLIRArithTransforms
   MLIRFuncDialect
   MLIRPass)