diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -15,6 +15,7 @@
 
 include "mlir/Analysis/CallInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/SideEffects.td"
 
 def Std_Dialect : Dialect {
   let name = "std";
@@ -893,7 +894,9 @@
     %3 = load %0[%1, %1] : memref<4x4xi32>
   }];
 
-  let arguments = (ins AnyMemRef:$memref, Variadic<Index>:$indices);
+  let arguments = (ins Arg<"the reference to load from", AnyMemRef,
+                           [MemRead]>:$memref,
+                       Variadic<Index>:$indices);
   let results = (outs AnyType:$result);
 
   let builders = [OpBuilder<
@@ -1313,8 +1316,10 @@
     store %v, %A[%i, %j] : memref<4x128xf32, (d0, d1) -> (d0, d1), 0>
   }];
 
-  let arguments = (ins AnyType:$value, AnyMemRef:$memref,
-                   Variadic<Index>:$indices);
+  let arguments = (ins AnyType:$value,
+                       Arg<"the reference to store to", AnyMemRef,
+                           [MemWrite]>:$memref,
+                       Variadic<Index>:$indices);
 
   let builders = [OpBuilder<
     "Builder *, OperationState &result, Value valueToStore, Value memref", [{
@@ -1580,7 +1585,8 @@
     %12 = tensor_load %10 : memref<4x?xf32, #layout, memspace0>
   }];
 
-  let arguments = (ins AnyMemRef:$memref);
+  let arguments = (ins Arg<"the reference to load from", AnyMemRef,
+                           [MemRead]>:$memref);
   let results = (outs AnyTensor:$result);
   // TensorLoadOp is fully verified by traits.
   let verifier = ?;
@@ -1620,7 +1626,9 @@
     tensor_store %8, %10 : memref<4x?xf32, #layout, memspace0>
   }];
 
-  let arguments = (ins AnyTensor:$tensor, AnyMemRef:$memref);
+  let arguments = (ins AnyTensor:$tensor,
+                       Arg<"the reference to store to", AnyMemRef,
+                           [MemWrite]>:$memref);
   // TensorStoreOp is fully verified by traits.
   let verifier = ?;
 
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1588,6 +1588,29 @@
   code body = b;
 }
 
+// A base decorator class that may optionally be added to OpVariables.
+class OpVariableDecorator;
+
+// Class for providing additional information on the variables, i.e. arguments
+// and results, of an operation.
+class OpVariable<string desc, Constraint varConstraint,
+                 list<OpVariableDecorator> varDecorators = []> {
+  // A description for the variable.
+  string description = desc;
+
+  // The constraint, either attribute or type, of the variable.
+  Constraint constraint = varConstraint;
+
+  // The list of decorators for this variable, e.g. side effects.
+  list<OpVariableDecorator> decorators = varDecorators;
+}
+class Arg<string desc, Constraint constraint,
+          list<OpVariableDecorator> decorators = []>
+  : OpVariable<desc, constraint, decorators>;
+class Res<string desc, Constraint constraint,
+          list<OpVariableDecorator> decorators = []>
+  : OpVariable<desc, constraint, decorators>;
+
 // Base class for all ops.
 class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> {
   // The dialect of the op.
diff --git a/mlir/include/mlir/IR/SideEffects.td b/mlir/include/mlir/IR/SideEffects.td
--- a/mlir/include/mlir/IR/SideEffects.td
+++ b/mlir/include/mlir/IR/SideEffects.td
@@ -107,7 +107,7 @@
 // This class is the general base side effect class. This is used by derived
 // effect interfaces to define their effects.
 class SideEffect<EffectOpInterfaceBase interface, string effectName,
-                 string resourceName> {
+                 string resourceName> : OpVariableDecorator {
   /// The parent interface that the effect belongs to.
   string interfaceTrait = interface.trait;
 
diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -57,6 +57,34 @@
   // Returns this op's C++ class name prefixed with namespaces.
   std::string getQualCppClassName() const;
 
+  /// A class used to represent the decorators of an operator variable, i.e.
+  /// argument or result.
+  struct VariableDecorator {
+  public:
+    explicit VariableDecorator(const llvm::Record *def) : def(def) {}
+    const llvm::Record &getDef() const { return *def; }
+
+  protected:
+    // The TableGen definition of this decorator.
+    const llvm::Record *def;
+  };
+
+  // A utility iterator over a list of variable decorators.
+  struct VariableDecoratorIterator
+      : public llvm::mapped_iterator<llvm::Init *const *,
+                                     VariableDecorator (*)(llvm::Init *)> {
+    using reference = VariableDecorator;
+
+    /// Wraps a raw iterator over the `Init *` elements of a decorator list.
+    VariableDecoratorIterator(llvm::Init *const *it)
+        : llvm::mapped_iterator<llvm::Init *const *,
+                                VariableDecorator (*)(llvm::Init *)>(it,
+                                                                     &unwrap) {}
+    static VariableDecorator unwrap(llvm::Init *init);
+  };
+  using var_decorator_iterator = VariableDecoratorIterator;
+  using var_decorator_range = llvm::iterator_range<VariableDecoratorIterator>;
+
   using value_iterator = NamedTypeConstraint *;
   using value_range = llvm::iterator_range<value_iterator>;
 
@@ -84,6 +112,8 @@
   TypeConstraint getResultTypeConstraint(int index) const;
   // Returns the `index`-th result's name.
   StringRef getResultName(int index) const;
+  // Returns the `index`-th result's decorators.
+  var_decorator_range getResultDecorators(int index) const;
 
   // Returns the number of variadic results in this operation.
   unsigned getNumVariadicResults() const;
@@ -128,6 +158,7 @@
   // Op argument (attribute or operand) accessors.
   Argument getArg(int index) const;
   StringRef getArgName(int index) const;
+  var_decorator_range getArgDecorators(int index) const;
 
   // Returns the trait wrapper for the given MLIR C++ `trait`.
   // TODO: We should add a C++ wrapper class for TableGen OpTrait instead of
diff --git a/mlir/include/mlir/TableGen/SideEffects.h b/mlir/include/mlir/TableGen/SideEffects.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/TableGen/SideEffects.h
@@ -0,0 +1,55 @@
+//===- SideEffects.h - Side Effects classes ---------------------*- C++ -*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Wrapper around side effect related classes defined in TableGen.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_SIDEEFFECTS_H_
+#define MLIR_TABLEGEN_SIDEEFFECTS_H_
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/TableGen/Operator.h"
+
+namespace mlir {
+namespace tblgen {
+
+// This class represents a specific instance of an effect that is being
+// exhibited.
+class SideEffect : public Operator::VariableDecorator {
+public:
+  // Return the name of the C++ effect.
+  StringRef getName() const;
+
+  // Return the name of the base C++ effect.
+  StringRef getBaseName() const;
+
+  // Return the name of the parent interface trait.
+  StringRef getInterfaceTrait() const;
+
+  // Return the name of the resource class.
+  StringRef getResource() const;
+
+  static bool classof(const Operator::VariableDecorator *var);
+};
+
+// This class represents an instance of a side effect interface applied to an
+// operation. This is a wrapper around an OpInterfaceTrait that also includes
+// the effects that are applied.
+class SideEffectTrait : public InterfaceOpTrait {
+public:
+  // Return the effects that are attached to the side effect interface.
+  Operator::var_decorator_range getEffects() const;
+
+  static bool classof(const OpTrait *t);
+};
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_SIDEEFFECTS_H_
diff --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt
--- a/mlir/lib/TableGen/CMakeLists.txt
+++ b/mlir/lib/TableGen/CMakeLists.txt
@@ -10,6 +10,7 @@
   OpTrait.cpp
   Pattern.cpp
   Predicate.cpp
+  SideEffects.cpp
   Successor.cpp
   Type.cpp
 
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -108,6 +108,15 @@
   return results->getArgNameStr(index);
 }
 
+auto tblgen::Operator::getResultDecorators(int index) const
+    -> var_decorator_range {
+  Record *result =
+      cast<DefInit>(def.getValueAsDag("results")->getArg(index))->getDef();
+  if (!result->isSubClassOf("OpVariable"))
+    return var_decorator_range(nullptr, nullptr);
+  return *result->getValueAsListInit("decorators");
+}
+
 unsigned tblgen::Operator::getNumVariadicResults() const {
   return std::count_if(
       results.begin(), results.end(),
@@ -137,6 +146,15 @@
   return argumentValues->getArgName(index)->getValue();
 }
 
+auto tblgen::Operator::getArgDecorators(int index) const
+    -> var_decorator_range {
+  Record *arg =
+      cast<DefInit>(def.getValueAsDag("arguments")->getArg(index))->getDef();
+  if (!arg->isSubClassOf("OpVariable"))
+    return var_decorator_range(nullptr, nullptr);
+  return *arg->getValueAsListInit("decorators");
+}
+
 const tblgen::OpTrait *tblgen::Operator::getTrait(StringRef trait) const {
   for (const auto &t : traits) {
     if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&t)) {
@@ -225,6 +243,7 @@
   auto typeConstraintClass = recordKeeper.getClass("TypeConstraint");
   auto attrClass = recordKeeper.getClass("Attr");
   auto derivedAttrClass = recordKeeper.getClass("DerivedAttr");
+  auto opVarClass = recordKeeper.getClass("OpVariable");
   numNativeAttributes = 0;
 
   DagInit *argumentValues = def.getValueAsDag("arguments");
@@ -239,10 +258,12 @@
       PrintFatalError(def.getLoc(),
                       Twine("undefined type for argument #") + Twine(i));
     Record *argDef = argDefInit->getDef();
+    if (argDef->isSubClassOf(opVarClass))
+      argDef = argDef->getValueAsDef("constraint");
 
     if (argDef->isSubClassOf(typeConstraintClass)) {
       operands.push_back(
-          NamedTypeConstraint{givenName, TypeConstraint(argDefInit)});
+          NamedTypeConstraint{givenName, TypeConstraint(argDef)});
     } else if (argDef->isSubClassOf(attrClass)) {
       if (givenName.empty())
         PrintFatalError(argDef->getLoc(), "attributes must be named");
@@ -284,6 +305,8 @@
   int operandIndex = 0, attrIndex = 0;
   for (unsigned i = 0; i != numArgs; ++i) {
     Record *argDef = dyn_cast<DefInit>(argumentValues->getArg(i))->getDef();
+    if (argDef->isSubClassOf(opVarClass))
+      argDef = argDef->getValueAsDef("constraint");
 
     if (argDef->isSubClassOf(typeConstraintClass)) {
       arguments.emplace_back(&operands[operandIndex++]);
@@ -302,11 +325,14 @@
   // Handle results.
   for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) {
     auto name = resultsDag->getArgNameStr(i);
-    auto *resultDef = dyn_cast<DefInit>(resultsDag->getArg(i));
-    if (!resultDef) {
+    auto *resultInit = dyn_cast<DefInit>(resultsDag->getArg(i));
+    if (!resultInit) {
       PrintFatalError(def.getLoc(),
                       Twine("undefined type for result #") + Twine(i));
     }
+    auto *resultDef = resultInit->getDef();
+    if (resultDef->isSubClassOf(opVarClass))
+      resultDef = resultDef->getValueAsDef("constraint");
     results.push_back({name, TypeConstraint(resultDef)});
   }
 
@@ -393,3 +419,8 @@
       os << "[operand] " << arg.get<NamedTypeConstraint *>()->name << '\n';
   }
 }
+
+auto tblgen::Operator::VariableDecoratorIterator::unwrap(llvm::Init *init)
+    -> VariableDecorator {
+  return VariableDecorator(cast<llvm::DefInit>(init)->getDef());
+}
diff --git a/mlir/lib/TableGen/SideEffects.cpp b/mlir/lib/TableGen/SideEffects.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/TableGen/SideEffects.cpp
@@ -0,0 +1,51 @@
+//===- SideEffects.cpp - SideEffect classes -------------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/SideEffects.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+//===----------------------------------------------------------------------===//
+// SideEffect
+//===----------------------------------------------------------------------===//
+
+StringRef SideEffect::getName() const {
+  return def->getValueAsString("effect");
+}
+
+StringRef SideEffect::getBaseName() const {
+  return def->getValueAsString("baseEffect");
+}
+
+StringRef SideEffect::getInterfaceTrait() const {
+  return def->getValueAsString("interfaceTrait");
+}
+
+StringRef SideEffect::getResource() const {
+  auto value = def->getValueAsString("resource");
+  return value.empty() ? "::mlir::SideEffects::DefaultResource" : value;
+}
+
+bool SideEffect::classof(const Operator::VariableDecorator *var) {
+  return var->getDef().isSubClassOf("SideEffect");
+}
+
+//===----------------------------------------------------------------------===//
+// SideEffectTrait
+//===----------------------------------------------------------------------===//
+
+Operator::var_decorator_range SideEffectTrait::getEffects() const {
+  auto *listInit = def->getValueAsListInit("effects");
+  return {listInit->begin(), listInit->end()};
+}
+
+bool SideEffectTrait::classof(const OpTrait *t) {
+  return t->getDef().isSubClassOf("SideEffectsTraitBase");
+}
diff --git a/mlir/test/mlir-tblgen/op-side-effects.td b/mlir/test/mlir-tblgen/op-side-effects.td
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-tblgen/op-side-effects.td
@@ -0,0 +1,26 @@
+// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/SideEffects.td"
+
+def TEST_Dialect : Dialect {
+  let name = "test";
+}
+class TEST_Op<string mnemonic, list<OpTrait> traits = []> :
+    Op<TEST_Dialect, mnemonic, traits>;
+
+def SideEffectOpA : TEST_Op<"side_effect_op_a"> {
+  let arguments = (ins Arg<"", Variadic<AnyMemRef>, [MemRead]>);
+  let results = (outs Res<"", AnyMemRef, [MemAlloc<"CustomResource">]>);
+}
+
+def SideEffectOpB : TEST_Op<"side_effect_op_b",
+    [MemoryEffects<[MemWrite<"CustomResource">]>]>;
+
+// CHECK: void SideEffectOpA::getEffects
+// CHECK:   for (Value value : getODSOperands(0))
+// CHECK:     effects.emplace_back(MemoryEffects::Read::get(), value, ::mlir::SideEffects::DefaultResource::get());
+// CHECK:   for (Value value : getODSResults(0))
+// CHECK:     effects.emplace_back(MemoryEffects::Allocate::get(), value, CustomResource::get());
+
+// CHECK: void SideEffectOpB::getEffects
+// CHECK:   effects.emplace_back(MemoryEffects::Write::get(), CustomResource::get());
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -20,6 +20,7 @@
 #include "mlir/TableGen/OpInterfaces.h"
 #include "mlir/TableGen/OpTrait.h"
 #include "mlir/TableGen/Operator.h"
+#include "mlir/TableGen/SideEffects.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Signals.h"
 #include "llvm/TableGen/Error.h"
@@ -279,6 +280,9 @@
   // Generate the OpInterface methods.
   void genOpInterfaceMethods();
 
+  // Generate the side effect interface methods.
+  void genSideEffectInterfaceMethods();
+
 private:
   // The TableGen record for this op.
   // TODO(antiagainst,zinenko): OpEmitter should not have a Record directly,
@@ -320,6 +324,7 @@
   genFolderDecls();
   genOpInterfaceMethods();
   generateOpFormat(op, opClass);
+  genSideEffectInterfaceMethods();
 }
 
 void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) {
@@ -1158,6 +1163,75 @@
   }
 }
 
+void OpEmitter::genSideEffectInterfaceMethods() {
+  enum EffectKind { Operand, Result, Static };
+  struct EffectLocation {
+    /// The effect applied.
+    SideEffect effect;
+
+    /// The index if the kind is either operand or result.
+    unsigned index : 30;
+
+    /// The kind of the location, holding an `EffectKind` value.
+    unsigned kind : 2;
+  };
+
+  StringMap<SmallVector<EffectLocation, 1>> interfaceEffects;
+  auto resolveDecorators = [&](Operator::var_decorator_range decorators,
+                               unsigned index, unsigned kind) {
+    for (auto decorator : decorators)
+      if (SideEffect *effect = dyn_cast<SideEffect>(&decorator))
+        interfaceEffects[effect->getInterfaceTrait()].push_back(
+            EffectLocation{*effect, index, kind});
+  };
+
+  // Collect effects that were specified via:
+  /// Traits.
+  for (const auto &trait : op.getTraits())
+    if (const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(&trait))
+      resolveDecorators(opTrait->getEffects(), /*index=*/0, EffectKind::Static);
+  /// Operands.
+  for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
+    if (op.getArg(i).is<NamedTypeConstraint *>()) {
+      resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
+      ++operandIt;
+    }
+  }
+  /// Results.
+  for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
+    resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result);
+
+  for (auto &it : interfaceEffects) {
+    StringRef baseEffect = it.second.front().effect.getBaseName();
+    auto effectsParam =
+        llvm::formatv(
+            "SmallVectorImpl<SideEffects::EffectInstance<{0}>> &effects",
+            baseEffect)
+            .str();
+
+    // Generate the 'getEffects' method.
+    auto &getEffects = opClass.newMethod("void", "getEffects", effectsParam);
+    auto &body = getEffects.body();
+
+    // Add effect instances for each of the locations marked on the operation.
+    for (auto &location : it.second) {
+      if (location.kind != EffectKind::Static) {
+        body << "  for (Value value : getODS"
+             << (location.kind == EffectKind::Operand ? "Operands" : "Results")
+             << "(" << location.index << "))\n  ";
+      }
+
+      body << "  effects.emplace_back(" << location.effect.getName()
+           << "::get()";
+
+      // If the effect isn't static, it has a specific value attached to it.
+      if (location.kind != EffectKind::Static)
+        body << ", value";
+      body << ", " << location.effect.getResource() << "::get());\n";
+    }
+  }
+}
+
 void OpEmitter::genParser() {
   if (!hasStringAttribute(def, "parser") ||
       hasStringAttribute(def, "assemblyFormat"))