diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaceBase.td b/mlir/include/mlir/Interfaces/SideEffectInterfaceBase.td --- a/mlir/include/mlir/Interfaces/SideEffectInterfaceBase.td +++ b/mlir/include/mlir/Interfaces/SideEffectInterfaceBase.td @@ -110,9 +110,20 @@ llvm::erase_if(effects, [&](auto &it) { return it.getValue() != value; }); } + /// Collect all of the effect instances that operate on the provided symbol + /// reference and place them in 'effects'. + void getEffectsOnSymbol(::mlir::SymbolRefAttr value, + llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance< + }] # baseEffect # [{>> & effects) { + getEffects(effects); + llvm::erase_if(effects, [&](auto &it) { + return it.getSymbolRef() != value; + }); + } + /// Collect all of the effect instances that operate on the provided /// resource and place them in 'effects'. - void getEffectsOnValue(::mlir::SideEffects::Resource *resource, + void getEffectsOnResource(::mlir::SideEffects::Resource *resource, llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance< }] # baseEffect # [{>> & effects) { getEffects(effects); diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h --- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h +++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h @@ -131,9 +131,9 @@ /// This class represents a specific instance of an effect. It contains the /// effect being applied, a resource that corresponds to where the effect is -/// applied, an optional value (either operand, result, or region entry -/// argument) that the effect is applied to, and an optional parameters -/// attribute further specifying the details of the effect. +/// applied, an optional symbol reference or value (either operand, result, +/// or region entry argument) that the effect is applied to, and an optional +/// parameters attribute further specifying the details of the effect.
template <typename EffectT> class EffectInstance { public: EffectInstance(EffectT *effect, Resource *resource = DefaultResource::get()) @@ -141,6 +141,9 @@ EffectInstance(EffectT *effect, Value value, Resource *resource = DefaultResource::get()) : effect(effect), resource(resource), value(value) {} + EffectInstance(EffectT *effect, SymbolRefAttr symbol, + Resource *resource = DefaultResource::get()) + : effect(effect), resource(resource), value(symbol) {} EffectInstance(EffectT *effect, Attribute parameters, Resource *resource = DefaultResource::get()) : effect(effect), resource(resource), parameters(parameters) {} @@ -148,13 +151,23 @@ Resource *resource = DefaultResource::get()) : effect(effect), resource(resource), value(value), parameters(parameters) {} + EffectInstance(EffectT *effect, SymbolRefAttr symbol, Attribute parameters, + Resource *resource = DefaultResource::get()) + : effect(effect), resource(resource), value(symbol), + parameters(parameters) {} /// Return the effect being applied. EffectT *getEffect() const { return effect; } /// Return the value the effect is applied on, or nullptr if there isn't a /// known value being affected. - Value getValue() const { return value; } + Value getValue() const { return value ? value.dyn_cast<Value>() : Value(); } + + /// Return the symbol reference the effect is applied on, or nullptr if there + /// isn't a known symbol being affected. + SymbolRefAttr getSymbolRef() const { + return value ? value.dyn_cast<SymbolRefAttr>() : SymbolRefAttr(); + } /// Return the resource that the effect applies to. Resource *getResource() const { return resource; } @@ -169,8 +182,8 @@ /// The resource that the given value resides in. Resource *resource; - /// The value that the effect applies to. This is optionally null. - Value value; + /// The symbol reference or value that the effect applies to. This is optionally null. + PointerUnion<SymbolRefAttr, Value> value; /// Additional parameters of the effect instance. An attribute is used for /// type-safe structured storage and context-based uniquing.
Concrete effects diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -94,6 +94,10 @@ // of `TypeAttrBase`). bool isTypeAttr() const; + // Returns true if this attribute is a symbol reference attribute (i.e., + // `SymbolRefAttr`, `FlatSymbolRefAttr`, or a subclass of either). + bool isSymbolRefAttr() const; + // Returns true if this attribute is an enum attribute (i.e., a subclass of // `EnumAttrInfo`) bool isEnumAttr() const; diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -55,6 +55,13 @@ bool Attribute::isTypeAttr() const { return isSubClassOf("TypeAttrBase"); } +bool Attribute::isSymbolRefAttr() const { + StringRef defName = def->getName(); + if (defName == "SymbolRefAttr" || defName == "FlatSymbolRefAttr") + return true; + return isSubClassOf("SymbolRefAttr") || isSubClassOf("FlatSymbolRefAttr"); +} + bool Attribute::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); } StringRef Attribute::getStorageType() const { diff --git a/mlir/test/IR/test-side-effects.mlir b/mlir/test/IR/test-side-effects.mlir --- a/mlir/test/IR/test-side-effects.mlir +++ b/mlir/test/IR/test-side-effects.mlir @@ -19,6 +19,11 @@ {effect="allocate", on_result, test_resource} ]} : () -> i32 +// expected-remark@+1 {{found an instance of 'read' on a symbol '@foo_ref', on resource ''}} +"test.side_effect_op"() {effects = [ + {effect="read", on_reference = @foo_ref, test_resource} +]} : () -> i32 + // No _memory_ effects, but a parametric test effect.
// expected-remark@+2 {{operation has no memory effects}} // expected-remark@+1 {{found a parametric effect with affine_map<(d0, d1) -> (d1, d0)>}} diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -744,17 +744,18 @@ .Case("read", MemoryEffects::Read::get()) .Case("write", MemoryEffects::Write::get()); - // Check for a result to affect. - Value value; - if (effectElement.get("on_result")) - value = getResult(); - // Check for a non-default resource to use. SideEffects::Resource *resource = SideEffects::DefaultResource::get(); if (effectElement.get("test_resource")) resource = TestResource::get(); - effects.emplace_back(effect, value, resource); + // Check for a result or symbol reference to affect. + if (effectElement.get("on_result")) + effects.emplace_back(effect, getResult(), resource); + else if (Attribute ref = effectElement.get("on_reference")) + effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); + else + effects.emplace_back(effect, resource); } } diff --git a/mlir/test/lib/IR/TestSideEffects.cpp b/mlir/test/lib/IR/TestSideEffects.cpp --- a/mlir/test/lib/IR/TestSideEffects.cpp +++ b/mlir/test/lib/IR/TestSideEffects.cpp @@ -43,6 +43,8 @@ if (instance.getValue()) diag << " on a value,"; + else if (SymbolRefAttr symbolRef = instance.getSymbolRef()) + diag << " on a symbol '" << symbolRef << "',"; diag << " on resource '" << instance.getResource()->getName() << "'"; } diff --git a/mlir/test/mlir-tblgen/op-side-effects.td b/mlir/test/mlir-tblgen/op-side-effects.td --- a/mlir/test/mlir-tblgen/op-side-effects.td +++ b/mlir/test/mlir-tblgen/op-side-effects.td @@ -11,7 +11,12 @@ def CustomResource : Resource<"CustomResource">; def SideEffectOpA : TEST_Op<"side_effect_op_a"> { - let arguments = (ins Arg<Variadic<AnyMemRef>, "", [MemRead]>); + let arguments = (ins + Arg<Variadic<AnyMemRef>, "", [MemRead]>, + Arg<SymbolRefAttr, "", [MemRead]>:$symbol, + Arg<FlatSymbolRefAttr, "", [MemWrite]>:$flat_symbol,
+ Arg<OptionalAttr<SymbolRefAttr>, "", [MemRead]>:$optional_symbol + ); let results = (outs Res<AnyMemRef, "", [MemAlloc<CustomResource>]>); } @@ -21,6 +26,10 @@ // CHECK: void SideEffectOpA::getEffects // CHECK: for (::mlir::Value value : getODSOperands(0)) // CHECK: effects.emplace_back(MemoryEffects::Read::get(), value, ::mlir::SideEffects::DefaultResource::get()); +// CHECK: effects.emplace_back(MemoryEffects::Read::get(), symbolAttr(), ::mlir::SideEffects::DefaultResource::get()); +// CHECK: effects.emplace_back(MemoryEffects::Write::get(), flat_symbolAttr(), ::mlir::SideEffects::DefaultResource::get()); +// CHECK: if (auto symbolRef = optional_symbolAttr()) +// CHECK: effects.emplace_back(MemoryEffects::Read::get(), symbolRef, ::mlir::SideEffects::DefaultResource::get()); // CHECK: for (::mlir::Value value : getODSResults(0)) // CHECK: effects.emplace_back(MemoryEffects::Allocate::get(), value, CustomResource::get()); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1627,12 +1627,12 @@ } void OpEmitter::genSideEffectInterfaceMethods() { - enum EffectKind { Operand, Result, Static }; + enum EffectKind { Operand, Result, Symbol, Static }; struct EffectLocation { /// The effect applied. SideEffect effect; - /// The index if the kind is either operand or result. + /// The index if the kind is not static. unsigned index : 30; /// The kind of the location. @@ -1661,17 +1661,29 @@ effects.push_back(EffectLocation{cast<SideEffect>(decorator), /*index=*/0, EffectKind::Static}); } - /// Operands. + /// Attributes and Operands.
for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) { - if (op.getArg(i).is<NamedTypeConstraint *>()) { + Argument arg = op.getArg(i); + if (arg.is<NamedTypeConstraint *>()) { resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand); ++operandIt; + continue; } + const NamedAttribute *attr = arg.get<NamedAttribute *>(); + if (attr->attr.getBaseAttr().isSymbolRefAttr()) + resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol); } /// Results. for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result); + // The code used to add an effect instance. + // {0}: The effect class. + // {1}: Optional value or symbol reference. + // {2}: The resource class. + const char *addEffectCode = + " effects.emplace_back({0}::get(), {1}{2}::get());\n"; + for (auto &it : interfaceEffects) { // Generate the 'getEffects' method. std::string type = llvm::formatv("::mlir::SmallVectorImpl<::mlir::" @@ -1684,19 +1696,31 @@ // Add effect instances for each of the locations marked on the operation. for (auto &location : it.second) { - if (location.kind != EffectKind::Static) { + StringRef effect = location.effect.getName(); + StringRef resource = location.effect.getResource(); + if (location.kind == EffectKind::Static) { + // A static instance has no attached value. + body << llvm::formatv(addEffectCode, effect, "", resource).str(); + } else if (location.kind == EffectKind::Symbol) { + // A symbol reference requires adding the raw attribute via `<name>Attr()`; + // the plain `<name>()` getter may return a converted, non-attribute type. + const auto *attr = op.getArg(location.index).get<NamedAttribute *>(); + if (attr->attr.isOptional()) { + body << " if (auto symbolRef = " << attr->name << "Attr())\n " + << llvm::formatv(addEffectCode, effect, "symbolRef, ", resource) + .str(); + } else { + body << llvm::formatv(addEffectCode, effect, attr->name + "Attr(), ", + resource) + .str(); + } + } else { + // Otherwise this is an operand/result, so we need to attach the Value. + body << " for (::mlir::Value value : getODS" << (location.kind == EffectKind::Operand ?
"Operands" : "Results") - << "(" << location.index << "))\n "; + << "(" << location.index << "))\n " + << llvm::formatv(addEffectCode, effect, "value, ", resource).str(); } - - body << " effects.emplace_back(" << location.effect.getName() - << "::get()"; - - // If the effect isn't static, it has a specific value attached to it. - if (location.kind != EffectKind::Static) - body << ", value"; - body << ", " << location.effect.getResource() << "::get());\n"; } } }