diff --git a/mlir/include/mlir/IR/AttributePropagationInterface.h b/mlir/include/mlir/IR/AttributePropagationInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/AttributePropagationInterface.h @@ -0,0 +1,42 @@ +//===- AttributePropagationInterface.h - Propagation Interface --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_IR_ATTRIBUTEPROPAGATIONINTERFACE_H_ +#define MLIR_IR_ATTRIBUTEPROPAGATIONINTERFACE_H_ + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/DialectInterface.h" +#include "mlir/IR/Identifier.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +class Operation; + +/// The AttributePropagationInterface allows dialects to provide a hook that is +/// invoked after operation creation to implicitly add/propagate dialect +/// attributes (i.e. a NamedAttribute for which the key is prefixed with the +/// dialect name). +class AttributePropagationInterface + : public DialectInterface::Base<AttributePropagationInterface> { +public: + AttributePropagationInterface(Dialect *dialect) : Base(dialect) {} + + /// Invoked for each attribute `attr` queued for propagation whose name is + /// prefixed with this dialect's name; the dialect decides whether to attach + /// it to `op`. The default implementation does nothing. + virtual void propagateAttribute(NamedAttribute attr, Operation *op) const {} +}; + +/// Propagate the attributes in `propagatedAttr` onto `op` through the +/// AttributePropagationInterface of the dialect named by each attribute's +/// name prefix. Unmatched attributes and a null dictionary are ignored.
+void propagateAttributes(DictionaryAttr propagatedAttr, Operation *op); + +} // end namespace mlir + +#endif // MLIR_IR_ATTRIBUTEPROPAGATIONINTERFACE_H_ diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -199,6 +199,7 @@ /// will cause subsequent insertions to go right before it. explicit OpBuilder(Operation *op, Listener *listener = nullptr) : OpBuilder(op->getContext(), listener) { + propagatedAttrs = op->getAttrDictionary(); setInsertionPoint(op); } @@ -477,6 +478,16 @@ return cast(cloneWithoutRegions(*op.getOperation())); } + /// Return the dictionary of attributes considered for propagation onto + /// operations created by this builder; a null dictionary disables it. + DictionaryAttr getPropagatedAttributes() const { return propagatedAttrs; } + + /// Set the dictionary of attributes that will be considered for propagation + /// onto operations created by this builder (null disables propagation). + void setPropagatedAttributes(DictionaryAttr attrs) { + propagatedAttrs = attrs; + } + private: /// The current block this builder is inserting into. Block *block = nullptr; @@ -485,6 +496,8 @@ Block::iterator insertPoint; /// The optional listener for events of this builder. Listener *listener; + /// Attributes considered for propagation on every created operation. + DictionaryAttr propagatedAttrs; }; } // namespace mlir diff --git a/mlir/lib/IR/AttributePropagationInterface.cpp b/mlir/lib/IR/AttributePropagationInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/IR/AttributePropagationInterface.cpp @@ -0,0 +1,29 @@ +//===- AttributePropagationInterface.cpp - Attribute Propagation --*- C++ -*-=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/AttributePropagationInterface.h" +#include "mlir/IR/Dialect.h" + +using namespace mlir; + +void mlir::propagateAttributes(DictionaryAttr propagatedAttr, Operation *op) { + if (!propagatedAttr) + return; + MLIRContext *ctx = propagatedAttr.getContext(); + for (NamedAttribute attr : propagatedAttr) { + auto dialectNamePair = attr.first.strref().split('.'); + if (dialectNamePair.second.empty()) + continue; + Dialect *dialect = ctx->getLoadedDialect(dialectNamePair.first); + if (!dialect) + continue; + if (auto *iface = + dialect->getRegisteredInterface<AttributePropagationInterface>()) + iface->propagateAttribute(attr, op); + } +} diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -9,6 +9,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/AttributePropagationInterface.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" @@ -395,7 +396,9 @@ /// Create an operation given the fields represented as an OperationState. Operation *OpBuilder::createOperation(const OperationState &state) { - return insert(Operation::create(state)); + Operation *newOp = Operation::create(state); + propagateAttributes(getPropagatedAttributes(), newOp); + return insert(newOp); } /// Attempts to fold the given operation and places new results within diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp --- a/mlir/lib/Rewrite/PatternApplicator.cpp +++ b/mlir/lib/Rewrite/PatternApplicator.cpp @@ -160,6 +160,7 @@ // benefit, so if we match we can immediately rewrite. For PDL patterns, the // match has already been performed, we just need to rewrite.
rewriter.setInsertionPoint(op); + rewriter.setPropagatedAttributes(op->getAttrDictionary()); LogicalResult result = success(); if (pdlMatch) { bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState); diff --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir --- a/mlir/test/Transforms/test-canonicalize.mlir +++ b/mlir/test/Transforms/test-canonicalize.mlir @@ -1,10 +1,10 @@ -// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='func(canonicalize)' | FileCheck %s +// RUN: mlir-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s // CHECK-LABEL: func @remove_op_with_inner_ops_pattern func @remove_op_with_inner_ops_pattern() { // CHECK-NEXT: return "test.op_with_region_pattern"() ({ - "foo.op_with_region_terminator"() : () -> () + "test.op_with_region_terminator"() : () -> () }) : () -> () return } @@ -13,7 +13,7 @@ func @remove_op_with_inner_ops_fold_no_side_effect() { // CHECK-NEXT: return "test.op_with_region_fold_no_side_effect"() ({ - "foo.op_with_region_terminator"() : () -> () + "test.op_with_region_terminator"() : () -> () }) : () -> () return } @@ -23,7 +23,7 @@ func @remove_op_with_inner_ops_fold(%arg0 : i32) -> (i32) { // CHECK-NEXT: return %[[ARG_0]] %0 = "test.op_with_region_fold"(%arg0) ({ - "foo.op_with_region_terminator"() : () -> () + "test.op_with_region_terminator"() : () -> () }) : (i32) -> (i32) return %0 : i32 } @@ -51,3 +51,14 @@ // CHECK-NEXT: return %[[O0]], %[[O1]] return %y, %z: i32, i32 } + +// The test dialect only propagates `test.propagate`; `test.foo` is dropped. +// CHECK-LABEL: func @propagate_attribute +func private @propagate_attribute_callee() +func @propagate_attribute() { + // CHECK-NOT: test.foo + // CHECK: call @propagate_attribute_callee() {test.propagate = 42 : i32} + // CHECK-NEXT: return + "test.fold_to_call_op"() {test.foo = "bar", test.propagate = 42 : i32, callee=@propagate_attribute_callee} : () -> () + return +} diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- 
a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -9,6 +9,7 @@ #include "TestDialect.h" #include "TestTypes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/AttributePropagationInterface.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/PatternMatch.h" @@ -91,6 +92,15 @@ } }; +struct TestDialectFoldAttributePropagationInterface + : public AttributePropagationInterface { + using AttributePropagationInterface::AttributePropagationInterface; + void propagateAttribute(NamedAttribute attr, Operation *op) const final { + if (attr.first == "test.propagate") + op->setAttr(attr.first, attr.second); + } +}; + /// This class defines the interface for handling inlining with standard /// operations. struct TestInlinerInterface : public DialectInlinerInterface { @@ -168,7 +178,8 @@ #define GET_OP_LIST #include "TestOps.cpp.inc" >(); - addInterfaces(); addTypes