Index: mlir/include/mlir/IR/Matchers.h
===================================================================
--- mlir/include/mlir/IR/Matchers.h
+++ mlir/include/mlir/IR/Matchers.h
@@ -79,6 +79,31 @@
   }
 };
 
+/// The matcher that matches a constant index operation and optionally binds
+/// the IntegerAttr it produces.
+struct constant_index_op_binder {
+  IntegerAttr *bind_value;
+
+  /// Creates a matcher instance that binds the index
+  /// attribute value if match succeeds.
+  explicit constant_index_op_binder(IntegerAttr *bv) : bind_value(bv) {}
+  /// Creates a matcher instance that doesn't bind if match succeeds.
+  constant_index_op_binder() : bind_value(nullptr) {}
+
+  /// Succeeds if `op` matches constant_op_binder<IntegerAttr> and its first
+  /// result has `index` type; `bind_value` is only written on success.
+  bool match(Operation *op) {
+    IntegerAttr attr;
+    if (!constant_op_binder<IntegerAttr>(&attr).match(op))
+      return false;
+    if (!op->getResult(0).getType().isIndex())
+      return false;
+    if (bind_value)
+      *bind_value = attr;
+    return true;
+  }
+};
+
 /// The matcher that matches a constant scalar / vector splat / tensor splat
 /// integer operation and binds the constant integer value.
 struct constant_int_op_binder {
@@ -213,6 +238,18 @@
   return detail::constant_op_binder<AttrT>(bind_value);
 }
 
+/// Matches a constant index operation.
+inline detail::constant_index_op_binder m_ConstantIndex() {
+  return detail::constant_index_op_binder();
+}
+
+/// Matches a constant index operation and binds its IntegerAttr value to
+/// bind_value.
+inline detail::constant_index_op_binder
+m_ConstantIndex(IntegerAttr *bind_value) {
+  return detail::constant_index_op_binder(bind_value);
+}
+
 /// Matches a constant scalar / vector splat / tensor splat integer one.
 inline detail::constant_int_value_matcher<1> m_One() {
   return detail::constant_int_value_matcher<1>();
Index: mlir/lib/Dialect/StandardOps/Ops.cpp
===================================================================
--- mlir/lib/Dialect/StandardOps/Ops.cpp
+++ mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -192,13 +192,6 @@
   return success();
 }
 
-/// Matches a ConstantIndexOp.
-/// TODO: This should probably just be a general matcher that uses m_Constant
-/// and checks the operation for an index type.
-static detail::op_matcher<ConstantIndexOp> m_ConstantIndex() {
-  return detail::op_matcher<ConstantIndexOp>();
-}
-
 //===----------------------------------------------------------------------===//
 // Common canonicalization pattern support logic
 //===----------------------------------------------------------------------===//
Index: mlir/test/IR/test-matchers.mlir
===================================================================
--- mlir/test/IR/test-matchers.mlir
+++ mlir/test/IR/test-matchers.mlir
@@ -41,3 +41,13 @@
 // CHECK-LABEL: test2
 // CHECK: Pattern add(add(a, constant), a) matched and bound constant to: 1.000000e+00
 // CHECK: Pattern add(add(a, constant), a) matched
+
+func @test3(%a: index) -> index {
+  %0 = constant 23: index
+  %1 = muli %0, %a: index
+  return %1: index
+}
+
+// CHECK-LABEL: test3
+// CHECK: Pattern mul(a, constantIndex) matched and bound constant to: 23
+// CHECK: Pattern mul(a, constant) matched and bound constant to: 23
Index: mlir/test/lib/IR/TestMatchers.cpp
===================================================================
--- mlir/test/lib/IR/TestMatchers.cpp
+++ mlir/test/lib/IR/TestMatchers.cpp
@@ -137,6 +137,24 @@
     llvm::outs() << "Pattern add(add(a, constant), a) matched\n";
 }
 
+void test3(FuncOp f) {
+  assert(f.getNumArguments() == 1 && "matcher test funcs must have 1 args");
+  auto a = m_Val(f.getArgument(0));
+  IntegerAttr integerAttrConstantIndex;
+  IntegerAttr integerAttrConstant;
+  auto mConstantIndex =
+      m_Op<MulIOp>(m_ConstantIndex(&integerAttrConstantIndex), a);
+  Operation *lastOp = f.getBody().front().back().getPrevNode();
+  if (mConstantIndex.match(lastOp))
+    llvm::outs()
+        << "Pattern mul(a, constantIndex) matched and bound constant to: "
+        << integerAttrConstantIndex.getInt() << "\n";
+  auto mConstant = m_Op<MulIOp>(m_Constant(&integerAttrConstant), a);
+  if (mConstant.match(lastOp))
+    llvm::outs() << "Pattern mul(a, constant) matched and bound constant to: "
+                 << integerAttrConstant.getInt() << "\n";
+}
+
 void TestMatchers::runOnFunction() {
   auto f = getFunction();
   llvm::outs() << f.getName() << "\n";
@@ -144,6 +162,8 @@
     test1(f);
   if (f.getName() == "test2")
     test2(f);
+  if (f.getName() == "test3")
+    test3(f);
 }
 
 static PassRegistration<TestMatchers> pass("test-matchers",