diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h --- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h +++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h @@ -126,7 +126,7 @@ llvm::StringRef name = {}, mlir::ValueRange shape = {}, mlir::ValueRange lenParams = {}, - llvm::ArrayRef attrs = {}); + mlir::AttributeRange attrs = {}); /// Create an unnamed and untracked temporary on the stack. mlir::Value createTemporary(mlir::Location loc, mlir::Type type, @@ -135,13 +135,13 @@ } mlir::Value createTemporary(mlir::Location loc, mlir::Type type, - llvm::ArrayRef attrs) { + mlir::AttributeRange attrs) { return createTemporary(loc, type, llvm::StringRef{}, {}, {}, attrs); } mlir::Value createTemporary(mlir::Location loc, mlir::Type type, llvm::StringRef name, - llvm::ArrayRef attrs) { + mlir::AttributeRange attrs) { return createTemporary(loc, type, name, {}, {}, attrs); } diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -34,7 +34,7 @@ def fir_OneResultOpBuilder : OpBuilder<(ins "mlir::Type":$resultType, "mlir::ValueRange":$operands, - CArg<"llvm::ArrayRef", "{}">:$attributes), + CArg<"mlir::AttributeRange", "{}">:$attributes), [{ if (resultType) $_state.addTypes(resultType); @@ -143,28 +143,28 @@ OpBuilder<(ins "mlir::Type":$inType, "llvm::StringRef":$uniqName, "llvm::StringRef":$bindcName, CArg<"mlir::ValueRange", "{}">:$typeparams, CArg<"mlir::ValueRange", "{}">:$shape, - CArg<"llvm::ArrayRef", "{}">:$attributes)>, + CArg<"mlir::AttributeRange", "{}">:$attributes)>, OpBuilder<(ins "mlir::Type":$inType, "llvm::StringRef":$uniqName, "llvm::StringRef":$bindcName, "bool":$pinned, CArg<"mlir::ValueRange", "{}">:$typeparams, CArg<"mlir::ValueRange", "{}">:$shape, - CArg<"llvm::ArrayRef", "{}">:$attributes)>, + 
CArg<"mlir::AttributeRange", "{}">:$attributes)>, OpBuilder<(ins "mlir::Type":$inType, "llvm::StringRef":$uniqName, CArg<"mlir::ValueRange", "{}">:$typeparams, CArg<"mlir::ValueRange", "{}">:$shape, - CArg<"llvm::ArrayRef", "{}">:$attributes)>, + CArg<"mlir::AttributeRange", "{}">:$attributes)>, OpBuilder<(ins "mlir::Type":$inType, "llvm::StringRef":$uniqName, "bool":$pinned, CArg<"mlir::ValueRange", "{}">:$typeparams, CArg<"mlir::ValueRange", "{}">:$shape, - CArg<"llvm::ArrayRef", "{}">:$attributes)>, + CArg<"mlir::AttributeRange", "{}">:$attributes)>, OpBuilder<(ins "mlir::Type":$inType, "bool":$pinned, CArg<"mlir::ValueRange", "{}">:$typeparams, CArg<"mlir::ValueRange", "{}">:$shape, - CArg<"llvm::ArrayRef", "{}">:$attributes)>, + CArg<"mlir::AttributeRange", "{}">:$attributes)>, OpBuilder<(ins "mlir::Type":$inType, CArg<"mlir::ValueRange", "{}">:$typeparams, CArg<"mlir::ValueRange", "{}">:$shape, - CArg<"llvm::ArrayRef", "{}">:$attributes)>]; + CArg<"mlir::AttributeRange", "{}">:$attributes)>]; let verifier = "return ::verify(*this);"; @@ -214,15 +214,15 @@ OpBuilder<(ins "mlir::Type":$in_type, "llvm::StringRef":$uniq_name, "llvm::StringRef":$bindc_name, CArg<"mlir::ValueRange", "{}">:$typeparams, CArg<"mlir::ValueRange", "{}">:$shape, - CArg<"llvm::ArrayRef", "{}">:$attributes)>, + CArg<"mlir::AttributeRange", "{}">:$attributes)>, OpBuilder<(ins "mlir::Type":$in_type, "llvm::StringRef":$uniq_name, CArg<"mlir::ValueRange", "{}">:$typeparams, CArg<"mlir::ValueRange", "{}">:$shape, - CArg<"llvm::ArrayRef", "{}">:$attributes)>, + CArg<"mlir::AttributeRange", "{}">:$attributes)>, OpBuilder<(ins "mlir::Type":$in_type, CArg<"mlir::ValueRange", "{}">:$typeparams, CArg<"mlir::ValueRange", "{}">:$shape, - CArg<"llvm::ArrayRef", "{}">:$attributes)>]; + CArg<"mlir::AttributeRange", "{}">:$attributes)>]; let verifier = "return ::verify(*this);"; @@ -522,7 +522,7 @@ "llvm::ArrayRef":$compareOperands, "llvm::ArrayRef":$destinations, CArg<"llvm::ArrayRef", 
"{}">:$destOperands, - CArg<"llvm::ArrayRef", "{}">:$attributes), + CArg<"mlir::AttributeRange", "{}">:$attributes), [{ $_state.addOperands(selector); llvm::SmallVector ivalues; @@ -715,13 +715,13 @@ "llvm::ArrayRef":$cmpOperands, "llvm::ArrayRef":$destinations, CArg<"llvm::ArrayRef", "{}">:$destOperands, - CArg<"llvm::ArrayRef", "{}">:$attributes)>, + CArg<"mlir::AttributeRange", "{}">:$attributes)>, OpBuilder<(ins "mlir::Value":$selector, "llvm::ArrayRef":$compareAttrs, "llvm::ArrayRef":$cmpOpList, "llvm::ArrayRef":$destinations, CArg<"llvm::ArrayRef", "{}">:$destOperands, - CArg<"llvm::ArrayRef", "{}">:$attributes)>]; + CArg<"mlir::AttributeRange", "{}">:$attributes)>]; let parser = "return parseSelectCase(parser, result);"; @@ -759,7 +759,7 @@ "llvm::ArrayRef":$typeOperands, "llvm::ArrayRef":$destinations, CArg<"llvm::ArrayRef", "{}">:$destOperands, - CArg<"llvm::ArrayRef", "{}">:$attributes)>]; + CArg<"mlir::AttributeRange", "{}">:$attributes)>]; let parser = "return parseSelectType(parser, result);"; @@ -2133,7 +2133,7 @@ "mlir::Value":$step, CArg<"bool", "false">:$unordered, CArg<"bool", "false">:$finalCountValue, CArg<"mlir::ValueRange", "llvm::None">:$iterArgs, - CArg<"llvm::ArrayRef", "{}">:$attributes)> + CArg<"mlir::AttributeRange", "{}">:$attributes)> ]; let extraClassDeclaration = [{ @@ -2277,7 +2277,7 @@ "mlir::Value":$step, "mlir::Value":$iterate, CArg<"bool", "false">:$finalCountValue, CArg<"mlir::ValueRange", "llvm::None">:$iterArgs, - CArg<"llvm::ArrayRef", "{}">:$attributes)> + CArg<"mlir::AttributeRange", "{}">:$attributes)> ]; let extraClassDeclaration = [{ @@ -2755,23 +2755,23 @@ let skipDefaultBuilders = 1; let builders = [ OpBuilder<(ins "llvm::StringRef":$name, "mlir::Type":$type, - CArg<"llvm::ArrayRef", "{}">:$attrs)>, + CArg<"mlir::AttributeRange", "{}">:$attrs)>, OpBuilder<(ins "llvm::StringRef":$name, "bool":$isConstant, "mlir::Type":$type, - CArg<"llvm::ArrayRef", "{}">:$attrs)>, + CArg<"mlir::AttributeRange", "{}">:$attrs)>, 
OpBuilder<(ins "llvm::StringRef":$name, "mlir::Type":$type, CArg<"mlir::StringAttr", "{}">:$linkage, - CArg<"llvm::ArrayRef", "{}">:$attrs)>, + CArg<"mlir::AttributeRange", "{}">:$attrs)>, OpBuilder<(ins "llvm::StringRef":$name, "bool":$isConstant, "mlir::Type":$type, CArg<"mlir::StringAttr", "{}">:$linkage, - CArg<"llvm::ArrayRef", "{}">:$attrs)>, + CArg<"mlir::AttributeRange", "{}">:$attrs)>, OpBuilder<(ins "llvm::StringRef":$name, "mlir::Type":$type, "mlir::Attribute":$initVal, CArg<"mlir::StringAttr", "{}">:$linkage, - CArg<"llvm::ArrayRef", "{}">:$attrs)>, + CArg<"mlir::AttributeRange", "{}">:$attrs)>, OpBuilder<(ins "llvm::StringRef":$name, "bool":$isConstant, "mlir::Type":$type, "mlir::Attribute":$initVal, CArg<"mlir::StringAttr", "{}">:$linkage, - CArg<"llvm::ArrayRef", "{}">:$attrs)>, + CArg<"mlir::AttributeRange", "{}">:$attrs)>, ]; let extraClassDeclaration = [{ @@ -2879,7 +2879,7 @@ let skipDefaultBuilders = 1; let builders = [ OpBuilder<(ins "llvm::StringRef":$name, "mlir::Type":$type, - CArg<"llvm::ArrayRef", "{}">:$attrs), + CArg<"mlir::AttributeRange", "{}">:$attrs), [{ $_state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), $_builder.getStringAttr(name)); diff --git a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h --- a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h +++ b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h @@ -52,12 +52,12 @@ /// FuncOp is created, and that new FuncOp is returned. mlir::FuncOp createFuncOp(mlir::Location loc, mlir::ModuleOp module, llvm::StringRef name, mlir::FunctionType type, - llvm::ArrayRef attrs = {}); + mlir::AttributeRange attrs = {}); /// Get or create a GlobalOp in a module. fir::GlobalOp createGlobalOp(mlir::Location loc, mlir::ModuleOp module, llvm::StringRef name, mlir::Type type, - llvm::ArrayRef attrs = {}); + mlir::AttributeRange attrs = {}); /// Attribute to mark Fortran entities with the CONTIGUOUS attribute. 
constexpr llvm::StringRef getContiguousAttrName() { return "fir.contiguous"; } diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp --- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp +++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp @@ -193,11 +193,12 @@ /// Create a temporary variable on the stack. Anonymous temporaries have no /// `name` value. Temporaries do not require a uniqued name. -mlir::Value -fir::FirOpBuilder::createTemporary(mlir::Location loc, mlir::Type type, - llvm::StringRef name, mlir::ValueRange shape, - mlir::ValueRange lenParams, - llvm::ArrayRef attrs) { +mlir::Value fir::FirOpBuilder::createTemporary(mlir::Location loc, + mlir::Type type, + llvm::StringRef name, + mlir::ValueRange shape, + mlir::ValueRange lenParams, + mlir::AttributeRange attrs) { llvm::SmallVector dynamicShape = elideExtentsAlreadyInType(type, shape); llvm::SmallVector dynamicLength = diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -169,7 +169,7 @@ mlir::OperationState &result, mlir::Type inType, llvm::StringRef uniqName, mlir::ValueRange typeparams, mlir::ValueRange shape, - llvm::ArrayRef attributes) { + mlir::AttributeRange attributes) { auto nameAttr = builder.getStringAttr(uniqName); build(builder, result, wrapAllocaResultType(inType), inType, nameAttr, {}, /*pinned=*/false, typeparams, shape); @@ -180,7 +180,7 @@ mlir::OperationState &result, mlir::Type inType, llvm::StringRef uniqName, bool pinned, mlir::ValueRange typeparams, mlir::ValueRange shape, - llvm::ArrayRef attributes) { + mlir::AttributeRange attributes) { auto nameAttr = builder.getStringAttr(uniqName); build(builder, result, wrapAllocaResultType(inType), inType, nameAttr, {}, pinned, typeparams, shape); @@ -191,7 +191,7 @@ mlir::OperationState &result, mlir::Type inType, llvm::StringRef uniqName, llvm::StringRef bindcName, 
mlir::ValueRange typeparams, mlir::ValueRange shape, - llvm::ArrayRef attributes) { + mlir::AttributeRange attributes) { auto nameAttr = uniqName.empty() ? mlir::StringAttr{} : builder.getStringAttr(uniqName); auto bindcAttr = @@ -206,7 +206,7 @@ llvm::StringRef uniqName, llvm::StringRef bindcName, bool pinned, mlir::ValueRange typeparams, mlir::ValueRange shape, - llvm::ArrayRef attributes) { + mlir::AttributeRange attributes) { auto nameAttr = uniqName.empty() ? mlir::StringAttr{} : builder.getStringAttr(uniqName); auto bindcAttr = @@ -219,7 +219,7 @@ void fir::AllocaOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Type inType, mlir::ValueRange typeparams, mlir::ValueRange shape, - llvm::ArrayRef attributes) { + mlir::AttributeRange attributes) { build(builder, result, wrapAllocaResultType(inType), inType, {}, {}, /*pinned=*/false, typeparams, shape); result.addAttributes(attributes); @@ -229,7 +229,7 @@ mlir::OperationState &result, mlir::Type inType, bool pinned, mlir::ValueRange typeparams, mlir::ValueRange shape, - llvm::ArrayRef attributes) { + mlir::AttributeRange attributes) { build(builder, result, wrapAllocaResultType(inType), inType, {}, {}, pinned, typeparams, shape); result.addAttributes(attributes); @@ -276,7 +276,7 @@ mlir::OperationState &result, mlir::Type inType, llvm::StringRef uniqName, mlir::ValueRange typeparams, mlir::ValueRange shape, - llvm::ArrayRef attributes) { + mlir::AttributeRange attributes) { auto nameAttr = builder.getStringAttr(uniqName); build(builder, result, wrapAllocMemResultType(inType), inType, nameAttr, {}, typeparams, shape); @@ -287,7 +287,7 @@ mlir::OperationState &result, mlir::Type inType, llvm::StringRef uniqName, llvm::StringRef bindcName, mlir::ValueRange typeparams, mlir::ValueRange shape, - llvm::ArrayRef attributes) { + mlir::AttributeRange attributes) { auto nameAttr = builder.getStringAttr(uniqName); auto bindcAttr = builder.getStringAttr(bindcName); build(builder, result, 
wrapAllocMemResultType(inType), inType, nameAttr, @@ -298,7 +298,7 @@ void fir::AllocMemOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Type inType, mlir::ValueRange typeparams, mlir::ValueRange shape, - llvm::ArrayRef attributes) { + mlir::AttributeRange attributes) { build(builder, result, wrapAllocMemResultType(inType), inType, {}, {}, typeparams, shape); result.addAttributes(attributes); @@ -1233,7 +1233,7 @@ void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, StringRef name, bool isConstant, Type type, Attribute initialVal, StringAttr linkage, - ArrayRef attrs) { + AttributeRange attrs) { result.addRegion(); result.addAttribute(typeAttrName(result.name), mlir::TypeAttr::get(type)); result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), @@ -1251,31 +1251,30 @@ void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, StringRef name, Type type, Attribute initialVal, - StringAttr linkage, ArrayRef attrs) { + StringAttr linkage, AttributeRange attrs) { build(builder, result, name, /*isConstant=*/false, type, {}, linkage, attrs); } void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, StringRef name, bool isConstant, Type type, - StringAttr linkage, ArrayRef attrs) { + StringAttr linkage, AttributeRange attrs) { build(builder, result, name, isConstant, type, {}, linkage, attrs); } void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, StringRef name, Type type, StringAttr linkage, - ArrayRef attrs) { + AttributeRange attrs) { build(builder, result, name, /*isConstant=*/false, type, {}, linkage, attrs); } void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, StringRef name, bool isConstant, Type type, - ArrayRef attrs) { + AttributeRange attrs) { build(builder, result, name, isConstant, type, StringAttr{}, attrs); } void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, - StringRef name, Type type, - ArrayRef 
attrs) { + StringRef name, Type type, AttributeRange attrs) { build(builder, result, name, /*isConstant=*/false, type, attrs); } @@ -1535,7 +1534,7 @@ mlir::Value ub, mlir::Value step, mlir::Value iterate, bool finalCountValue, mlir::ValueRange iterArgs, - llvm::ArrayRef attributes) { + mlir::AttributeRange attributes) { result.addOperands({lb, ub, step, iterate}); if (finalCountValue) { result.addTypes(builder.getIndexType()); @@ -1861,7 +1860,7 @@ mlir::OperationState &result, mlir::Value lb, mlir::Value ub, mlir::Value step, bool unordered, bool finalCountValue, mlir::ValueRange iterArgs, - llvm::ArrayRef attributes) { + mlir::AttributeRange attributes) { result.addOperands({lb, ub, step}); result.addOperands(iterArgs); if (finalCountValue) { @@ -2287,8 +2286,7 @@ getMutableSuccessorOperands(unsigned pos, mlir::MutableOperandRange operands, StringRef offsetAttr) { Operation *owner = operands.getOwner(); - NamedAttribute targetOffsetAttr = - *owner->getAttrDictionary().getNamed(offsetAttr); + NamedAttribute targetOffsetAttr = *owner->getAttrs().getNamed(offsetAttr); return getSubOperands( pos, operands, targetOffsetAttr.second.cast(), mlir::MutableOperandRange::OperandSegment(pos, targetOffsetAttr)); @@ -2486,7 +2484,7 @@ llvm::ArrayRef cmpOperands, llvm::ArrayRef destinations, llvm::ArrayRef destOperands, - llvm::ArrayRef attributes) { + mlir::AttributeRange attributes) { result.addOperands(selector); result.addAttribute(getCasesAttr(), builder.getArrayAttr(compareAttrs)); llvm::SmallVector operOffs; @@ -2539,7 +2537,7 @@ llvm::ArrayRef cmpOpList, llvm::ArrayRef destinations, llvm::ArrayRef destOperands, - llvm::ArrayRef attributes) { + mlir::AttributeRange attributes) { llvm::SmallVector cmpOpers; auto iter = cmpOpList.begin(); for (auto &attr : compareAttrs) { @@ -2750,7 +2748,7 @@ llvm::ArrayRef typeOperands, llvm::ArrayRef destinations, llvm::ArrayRef destOperands, - llvm::ArrayRef attributes) { + mlir::AttributeRange attributes) { 
result.addOperands(selector); result.addAttribute(getCasesAttr(), builder.getArrayAttr(typeOperands)); const auto count = destinations.size(); @@ -3195,7 +3193,7 @@ mlir::FuncOp fir::createFuncOp(mlir::Location loc, mlir::ModuleOp module, StringRef name, mlir::FunctionType type, - llvm::ArrayRef attrs) { + mlir::AttributeRange attrs) { if (auto f = module.lookupSymbol(name)) return f; mlir::OpBuilder modBuilder(module.getBodyRegion()); @@ -3207,7 +3205,7 @@ fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module, StringRef name, mlir::Type type, - llvm::ArrayRef attrs) { + mlir::AttributeRange attrs) { if (auto g = module.lookupSymbol(name)) return g; mlir::OpBuilder modBuilder(module.getBodyRegion()); diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md --- a/mlir/docs/DeclarativeRewrites.md +++ b/mlir/docs/DeclarativeRewrites.md @@ -236,8 +236,8 @@ several `build()` methods generated for it. One of them has aggregated parameters for result types, operands, and attributes in the signature: `void COp::build(..., ArrayRef resultTypes, Array operands, -ArrayRef attr)`. The pattern in the above calls this `build()` -method for constructing the `COp`. +AttributeRange attr)`. The pattern in the above calls this `build()` method for +constructing the `COp`. In general, arguments in the result pattern will be passed directly to the `build()` method to leverage the auto-generated `build()` method, list them in diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -354,8 +354,8 @@ Both operation traits, [interfaces](Interfaces.md/#utilizing-the-ods-framework), and constraints involving multiple operands/attributes/results are provided as -the third template parameter to the `Op` class. They should be deriving from -the `OpTrait` class. See [Constraints](#constraints) for more information. +the third template parameter to the `Op` class. 
They should derive from the +`OpTrait` class. See [Constraints](#constraints) for more information. ### Builder methods @@ -389,7 +389,7 @@ static void build(OpBuilder &odsBuilder, OperationState &odsState, ArrayRef resultTypes, ValueRange operands, - ArrayRef attributes); + AttributeRange attributes); // Each result-type/operand/attribute has a separate parameter. The parameters // for attributes are of mlir::Attribute types. @@ -416,7 +416,7 @@ // All operands/attributes have aggregate parameters. // Generated if return type can be inferred. static void build(OpBuilder &odsBuilder, OperationState &odsState, - ValueRange operands, ArrayRef attributes); + ValueRange operands, AttributeRange attributes); // (And manually specified builders depending on the specific op.) ``` @@ -561,7 +561,7 @@ Verification code will be automatically generated for [constraints](#constraints) specified on various entities of the op. To perform -_additional_ verification, you can use +*additional* verification, you can use ```tablegen let verifier = [{ @@ -830,13 +830,14 @@ ##### Unit Attributes -In MLIR, the [`unit` Attribute](Dialects/Builtin.md/#unitattr) is special in that it -only has one possible value, i.e. it derives meaning from its existence. When a -unit attribute is used to anchor an optional group and is not the first element -of the group, the presence of the unit attribute can be directly correlated with -the presence of the optional group itself. As such, in these situations the unit -attribute will not be printed or present in the output and will be automatically -inferred when parsing by the presence of the optional group itself. +In MLIR, the [`unit` Attribute](Dialects/Builtin.md/#unitattr) is special in +that it only has one possible value, i.e. it derives meaning from its existence.
+When a unit attribute is used to anchor an optional group and is not the first element of the group, the presence of the unit attribute can be directly correlated with the presence of the optional group itself. As such, in these situations the unit attribute will not be printed or present in the output and will be automatically inferred when parsing by the presence of the optional group itself. For example, the following operation: @@ -999,7 +1000,7 @@ #### Operand adaptors -For each operation, we automatically generate an _operand adaptor_. This class +For each operation, we automatically generate an *operand adaptor*. This class solves the problem of accessing operands provided as a list of `Value`s without using "magic" constants. The operand adaptor takes a reference to an array of `Value` and provides methods with the same names as those in the operation class @@ -1116,11 +1117,11 @@ information of the current operation. * `$_self` will be replaced with the entity this predicate is attached to. E.g., `BoolAttr` is an attribute constraint that wraps a - `CPred<"$_self.isa<BoolAttr>()">`. Then for `BoolAttr:$attr`,`$_self` will be - replaced by `$attr`. For type constraints, it's a little bit special since - we want the constraints on each type definition reads naturally and we want - to attach type constraints directly to an operand/result, `$_self` will be - replaced by the operand/result's type. E.g., for `F32` in `F32:$operand`, + `CPred<"$_self.isa<BoolAttr>()">`. Then for `BoolAttr:$attr`, `$_self` will + be replaced by `$attr`. For type constraints, it's a little bit special + since we want the constraints on each type definition to read naturally and we + want to attach type constraints directly to an operand/result, `$_self` will + be replaced by the operand/result's type. E.g., for `F32` in `F32:$operand`, its `$_self` will be expanded as `operand(...).getType()`. TODO: Reconsider the leading symbol for special placeholders.
Eventually we want @@ -1199,9 +1200,9 @@ bitwidth. ODS attributes are defined as having a storage type (corresponding to a backing -`mlir::Attribute` that _stores_ the attribute), a return type (corresponding to -the C++ _return_ type of the generated helper getters) as well as a method -to convert between the internal storage and the helper method. +`mlir::Attribute` that *stores* the attribute), a return type (corresponding to +the C++ *return* type of the generated helper getters) as well as a method to +convert between the internal storage and the helper method. ### Attribute decorators @@ -1429,10 +1430,10 @@ ## Type Definitions -MLIR defines the `TypeDef` class hierarchy to enable generation of data types from -their specifications. A type is defined by specializing the `TypeDef` class with -concrete contents for all the fields it requires. For example, an integer type -could be defined as: +MLIR defines the `TypeDef` class hierarchy to enable generation of data types +from their specifications. A type is defined by specializing the `TypeDef` class +with concrete contents for all the fields it requires. For example, an integer +type could be defined as: ```tablegen // All of the types will extend this class. @@ -1680,10 +1681,10 @@ This builder is identical to the one that will be automatically generated for `MyType`. The `context` parameter is implicitly added by the generator, and is -used when building the Type instance (with `Base::get`). The distinction -here is that we can provide the implementation of this `get` method. With this -style of builder definition only the declaration is generated, the implementor -of `MyType` will need to provide a definition of `MyType::get`. +used when building the Type instance (with `Base::get`). The distinction here is +that we can provide the implementation of this `get` method. 
With this style of +builder definition only the declaration is generated, the implementor of +`MyType` will need to provide a definition of `MyType::get`. The second builder will generate the declaration of a builder method that looks like: diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp @@ -110,7 +110,8 @@ // Generate an adaptor for the remapped operands of the BinaryOp. This // allows for using the nice named accessors that are generated by the // ODS. - typename BinaryOp::Adaptor binaryAdaptor(memRefOperands); + typename BinaryOp::Adaptor binaryAdaptor(builder.getContext(), + memRefOperands); // Generate loads for the element of 'lhs' and 'rhs' at the inner // loop. @@ -253,7 +254,8 @@ // Generate an adaptor for the remapped operands of the // TransposeOp. This allows for using the nice named // accessors that are generated by the ODS. - toy::TransposeOpAdaptor transposeAdaptor(memRefOperands); + toy::TransposeOpAdaptor transposeAdaptor( + builder.getContext(), memRefOperands); Value input = transposeAdaptor.input(); // Transpose the elements by generating a load from the diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp @@ -110,7 +110,8 @@ // Generate an adaptor for the remapped operands of the BinaryOp. This // allows for using the nice named accessors that are generated by the // ODS. - typename BinaryOp::Adaptor binaryAdaptor(memRefOperands); + typename BinaryOp::Adaptor binaryAdaptor(builder.getContext(), + memRefOperands); // Generate loads for the element of 'lhs' and 'rhs' at the inner // loop. @@ -253,7 +254,8 @@ // Generate an adaptor for the remapped operands of the // TransposeOp. 
This allows for using the nice named // accessors that are generated by the ODS. - toy::TransposeOpAdaptor transposeAdaptor(memRefOperands); + toy::TransposeOpAdaptor transposeAdaptor( + builder.getContext(), memRefOperands); Value input = transposeAdaptor.input(); // Transpose the elements by generating a load from the diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp @@ -110,7 +110,8 @@ // Generate an adaptor for the remapped operands of the BinaryOp. This // allows for using the nice named accessors that are generated by the // ODS. - typename BinaryOp::Adaptor binaryAdaptor(memRefOperands); + typename BinaryOp::Adaptor binaryAdaptor(builder.getContext(), + memRefOperands); // Generate loads for the element of 'lhs' and 'rhs' at the inner // loop. @@ -253,7 +254,8 @@ // Generate an adaptor for the remapped operands of the // TransposeOp. This allows for using the nice named // accessors that are generated by the ODS. - toy::TransposeOpAdaptor transposeAdaptor(memRefOperands); + toy::TransposeOpAdaptor transposeAdaptor( + builder.getContext(), memRefOperands); Value input = transposeAdaptor.input(); // Transpose the elements by generating a load from the diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -142,8 +142,8 @@ /// Wrappers around the RewritePattern methods that pass the derived op type. 
void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - rewrite(cast(op), OpAdaptor(operands, op->getAttrDictionary()), - rewriter); + rewrite(cast(op), + OpAdaptor(getContext(), operands, op->getAttrs()), rewriter); } LogicalResult match(Operation *op) const final { return match(cast(op)); @@ -152,7 +152,7 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { return matchAndRewrite(cast(op), - OpAdaptor(operands, op->getAttrDictionary()), + OpAdaptor(getContext(), operands, op->getAttrs()), rewriter); } @@ -168,7 +168,7 @@ matchAndRewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (succeeded(match(op))) { - rewrite(op, OpAdaptor(operands, op->getAttrDictionary()), rewriter); + rewrite(op, OpAdaptor(getContext(), operands, op->getAttrs()), rewriter); return success(); } return failure(); diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -140,7 +140,7 @@ let builders = [ OpBuilder<(ins "Value":$operand, - CArg<"ArrayRef", "{}">:$attrs)>, + CArg<"AttributeRange", "{}">:$attrs)>, ]; let extraClassDeclaration = [{ diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -202,7 +202,7 @@ OpBuilder<(ins "StringRef":$name, "FunctionType":$type, CArg<"TypeRange", "{}">:$workgroupAttributions, CArg<"TypeRange", "{}">:$privateAttributions, - CArg<"ArrayRef", "{}">:$attrs)> + CArg<"AttributeRange", "{}">:$attrs)> ]; let extraClassDeclaration = [{ diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ 
-68,7 +68,7 @@ def LLVM_OneResultOpBuilder : OpBuilder<(ins "Type":$resultType, "ValueRange":$operands, - CArg<"ArrayRef", "{}">:$attributes), + CArg<"AttributeRange", "{}">:$attributes), [{ if (resultType) $_state.addTypes(resultType); $_state.addOperands(operands); @@ -79,7 +79,7 @@ def LLVM_ZeroResultOpBuilder : OpBuilder<(ins "ValueRange":$operands, - CArg<"ArrayRef", "{}">:$attributes), + CArg<"AttributeRange", "{}">:$attributes), [{ $_state.addOperands(operands); for (auto namedAttr : attributes) { @@ -91,7 +91,7 @@ // to indicate no result. def LLVM_VoidResultTypeOpBuilder : OpBuilder<(ins "Type":$resultType, "ValueRange":$operands, - CArg<"ArrayRef", "{}">:$attributes), + CArg<"AttributeRange", "{}">:$attributes), [{ assert(isCompatibleType(resultType) && "result must be an LLVM type"); assert(resultType.isa() && @@ -103,7 +103,7 @@ // Opaque builder used for terminator operations that contain successors. def LLVM_TerminatorPassthroughOpBuilder : OpBuilder<(ins "ValueRange":$operands, "SuccessorRange":$destinations, - CArg<"ArrayRef", "{}">:$attributes), + CArg<"AttributeRange", "{}">:$attributes), [{ $_state.addOperands(operands); $_state.addSuccessors(destinations); @@ -515,7 +515,7 @@ let results = (outs Variadic); let builders = [ OpBuilder<(ins "LLVMFuncOp":$func, "ValueRange":$operands, - CArg<"ArrayRef", "{}">:$attributes), [{ + CArg<"AttributeRange", "{}">:$attributes), [{ Type resultType = func.getType().getReturnType(); if (!resultType.isa()) $_state.addTypes(resultType); @@ -544,7 +544,7 @@ }]; let builders = [ OpBuilder<(ins "Value":$vector, "Value":$position, - CArg<"ArrayRef", "{}">:$attrs)>]; + CArg<"AttributeRange", "{}">:$attrs)>]; let verifier = [{ return ::verify(*this); }]; let parser = [{ return parseExtractElementOp(parser, result); }]; let printer = [{ printExtractElementOp(p, *this); }]; @@ -600,7 +600,7 @@ }]; let builders = [ OpBuilder<(ins "Value":$v1, "Value":$v2, "ArrayAttr":$mask, - CArg<"ArrayRef", "{}">:$attrs)>]; + 
CArg<"AttributeRange", "{}">:$attrs)>]; let verifier = [{ auto type1 = v1().getType(); auto type2 = v2().getType(); @@ -873,7 +873,7 @@ let builders = [ OpBuilder<(ins "GlobalOp":$global, - CArg<"ArrayRef", "{}">:$attrs), + CArg<"AttributeRange", "{}">:$attrs), [{ build($_builder, $_state, LLVM::LLVMPointerType::get(global.getType(), global.addr_space()), @@ -881,7 +881,7 @@ $_state.addAttributes(attrs); }]>, OpBuilder<(ins "LLVMFuncOp":$func, - CArg<"ArrayRef", "{}">:$attrs), + CArg<"AttributeRange", "{}">:$attrs), [{ build($_builder, $_state, LLVM::LLVMPointerType::get(func.getType()), func.getName()); @@ -1122,7 +1122,7 @@ CArg<"uint64_t", "0">:$alignment, CArg<"unsigned", "0">:$addrSpace, CArg<"bool", "false">:$dsoLocal, - CArg<"ArrayRef", "{}">:$attrs)> + CArg<"AttributeRange", "{}">:$attrs)> ]; let extraClassDeclaration = [{ @@ -1258,7 +1258,7 @@ OpBuilder<(ins "StringRef":$name, "Type":$type, CArg<"Linkage", "Linkage::External">:$linkage, CArg<"bool", "false">:$dsoLocal, - CArg<"ArrayRef", "{}">:$attrs, + CArg<"AttributeRange", "{}">:$attrs, CArg<"ArrayRef", "{}">:$argAttrs)> ]; diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -123,7 +123,7 @@ build($_builder, $_state, ValueRange{}, staticShape, elementType); }]>, OpBuilder<(ins "ArrayRef":$sizes, "Type":$elementType, - CArg<"ArrayRef", "{}">:$attrs)> + CArg<"AttributeRange", "{}">:$attrs)> ]; let hasCanonicalizer = 1; @@ -313,17 +313,17 @@ OpBuilder<(ins "Value":$source, "ArrayRef":$staticLow, "ArrayRef":$staticHigh, "ValueRange":$low, "ValueRange":$high, CArg<"bool", "false">:$nofold, - CArg<"ArrayRef", "{}">:$attrs)>, + CArg<"AttributeRange", "{}">:$attrs)>, // Build a PadTensorOp with all dynamic entries. 
OpBuilder<(ins "Value":$source, "ValueRange":$low, "ValueRange":$high, CArg<"bool", "false">:$nofold, - CArg<"ArrayRef", "{}">:$attrs)>, + CArg<"AttributeRange", "{}">:$attrs)>, // Build a PadTensorOp with mixed static and dynamic entries and custom // result type. If the type passed is nullptr, it is inferred. OpBuilder<(ins "Type":$resultType, "Value":$source, "ArrayRef":$low, "ArrayRef":$high, CArg<"bool", "false">:$nofold, - CArg<"ArrayRef", "{}">:$attrs)>, + CArg<"AttributeRange", "{}">:$attrs)>, ]; let hasCanonicalizer = 1; @@ -365,10 +365,10 @@ // `src` and `reassociation`. OpBuilder<(ins "Value":$src, "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs)>, + CArg<"AttributeRange", "{}">:$attrs)>, OpBuilder<(ins "Value":$src, "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), + CArg<"AttributeRange", "{}">:$attrs), [{ auto reassociationMaps = convertReassociationMapsToIndices($_builder, reassociation); @@ -379,7 +379,7 @@ // be either a contracting or expanding reshape. 
OpBuilder<(ins "Type":$resultType, "Value":$src, "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), + CArg<"AttributeRange", "{}">:$attrs), [{ build($_builder, $_state, resultType, src, attrs); $_state.addAttribute("reassociation", @@ -387,7 +387,7 @@ }]>, OpBuilder<(ins "Type":$resultType, "Value":$src, "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), + CArg<"AttributeRange", "{}">:$attrs), [{ auto reassociationMaps = convertReassociationMapsToIndices($_builder, reassociation); diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -122,7 +122,7 @@ OpBuilder<(ins "Value":$input, "Value":$output, CArg<"AffineMap", "AffineMap()">:$inputPermutation, CArg<"AffineMap", "AffineMap()">:$outputPermutation, - CArg<"ArrayRef", "{}">:$attrs)>]; + CArg<"AttributeRange", "{}">:$attrs)>]; let extraClassDeclaration = structuredOpsDecls # [{ ValueRange inputs() { return getOperands().take_front(); } @@ -338,21 +338,21 @@ "ArrayRef":$iteratorTypes, "StringRef":$doc, "StringRef":$libraryCall, CArg<"function_ref", "nullptr">, - CArg<"ArrayRef", "{}">:$attributes)>, + CArg<"AttributeRange", "{}">:$attributes)>, OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers, "ArrayRef":$indexingMaps, "ArrayRef":$iteratorTypes, "StringRef":$doc, "StringRef":$libraryCall, CArg<"function_ref", "nullptr">, - CArg<"ArrayRef", "{}">:$attributes)>, + CArg<"AttributeRange", "{}">:$attributes)>, OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, "ValueRange":$outputs, "ArrayRef":$indexingMaps, "ArrayRef":$iteratorTypes, CArg<"function_ref", "nullptr">, - CArg<"ArrayRef", "{}">:$attributes)>, + CArg<"AttributeRange", "{}">:$attributes)>, OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers, "ArrayRef":$indexingMaps, 
"ArrayRef":$iteratorTypes, CArg<"function_ref", "nullptr">, - CArg<"ArrayRef", "{}">:$attributes)> + CArg<"AttributeRange", "{}">:$attributes)> ]; let extraClassDeclaration = structuredOpsBaseDecls # [{ diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1053,17 +1053,17 @@ OpBuilder<(ins "MemRefType":$resultType, "Value":$source, "OpFoldResult":$offset, "ArrayRef":$sizes, "ArrayRef":$strides, - CArg<"ArrayRef", "{}">:$attrs)>, + CArg<"AttributeRange", "{}">:$attrs)>, // Build a ReinterpretCastOp with static entries. OpBuilder<(ins "MemRefType":$resultType, "Value":$source, "int64_t":$offset, "ArrayRef":$sizes, "ArrayRef":$strides, - CArg<"ArrayRef", "{}">:$attrs)>, + CArg<"AttributeRange", "{}">:$attrs)>, // Build a ReinterpretCastOp with dynamic entries. OpBuilder<(ins "MemRefType":$resultType, "Value":$source, "Value":$offset, "ValueRange":$sizes, "ValueRange":$strides, - CArg<"ArrayRef", "{}">:$attrs)> + CArg<"AttributeRange", "{}">:$attrs)> ]; let extraClassDeclaration = extraBaseClassDeclaration # [{ @@ -1170,10 +1170,10 @@ // `src` and `reassociation`. OpBuilder<(ins "Value":$src, "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs)>, + CArg<"AttributeRange", "{}">:$attrs)>, OpBuilder<(ins "Value":$src, "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), + CArg<"AttributeRange", "{}">:$attrs), [{ auto reassociationMaps = convertReassociationMapsToIndices($_builder, reassociation); @@ -1184,7 +1184,7 @@ // be either a contracting or expanding reshape. 
OpBuilder<(ins "Type":$resultType, "Value":$src, "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), + CArg<"AttributeRange", "{}">:$attrs), [{ build($_builder, $_state, resultType, src, attrs); $_state.addAttribute("reassociation", @@ -1192,7 +1192,7 @@ }]>, OpBuilder<(ins "Type":$resultType, "Value":$src, "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), + CArg<"AttributeRange", "{}">:$attrs), [{ auto reassociationMaps = convertReassociationMapsToIndices($_builder, reassociation); @@ -1528,32 +1528,32 @@ // result type. If the type passed is nullptr, it is inferred. OpBuilder<(ins "Value":$source, "ArrayRef":$offsets, "ArrayRef":$sizes, "ArrayRef":$strides, - CArg<"ArrayRef", "{}">:$attrs)>, + CArg<"AttributeRange", "{}">:$attrs)>, // Build a SubViewOp with mixed static and dynamic entries and inferred // result type. OpBuilder<(ins "MemRefType":$resultType, "Value":$source, "ArrayRef":$offsets, "ArrayRef":$sizes, "ArrayRef":$strides, - CArg<"ArrayRef", "{}">:$attrs)>, + CArg<"AttributeRange", "{}">:$attrs)>, // Build a SubViewOp with static entries and custom result type. If the // type passed is nullptr, it is inferred. OpBuilder<(ins "Value":$source, "ArrayRef":$offsets, "ArrayRef":$sizes, "ArrayRef":$strides, - CArg<"ArrayRef", "{}">:$attrs)>, + CArg<"AttributeRange", "{}">:$attrs)>, // Build a SubViewOp with static entries and inferred result type. OpBuilder<(ins "MemRefType":$resultType, "Value":$source, "ArrayRef":$offsets, "ArrayRef":$sizes, "ArrayRef":$strides, - CArg<"ArrayRef", "{}">:$attrs)>, + CArg<"AttributeRange", "{}">:$attrs)>, // Build a SubViewOp with dynamic entries and custom result type. If the // type passed is nullptr, it is inferred. OpBuilder<(ins "Value":$source, "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides, - CArg<"ArrayRef", "{}">:$attrs)>, + CArg<"AttributeRange", "{}">:$attrs)>, // Build a SubViewOp with dynamic entries and inferred result type. 
OpBuilder<(ins "MemRefType":$resultType, "Value":$source, "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides, - CArg<"ArrayRef", "{}">:$attrs)> + CArg<"AttributeRange", "{}">:$attrs)> ]; let extraClassDeclaration = extraBaseClassDeclaration # [{ @@ -1719,7 +1719,7 @@ let builders = [ OpBuilder<(ins "Value":$in, "AffineMapAttr":$permutation, - CArg<"ArrayRef", "{}">:$attrs)>]; + CArg<"AttributeRange", "{}">:$attrs)>]; let extraClassDeclaration = [{ static StringRef getPermutationAttrName() { return "permutation"; } diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -117,7 +117,7 @@ let regions = (region AnyRegion:$region); let builders = [ - OpBuilder<(ins CArg<"ArrayRef", "{}">:$attributes)> + OpBuilder<(ins CArg<"AttributeRange", "{}">:$attributes)> ]; let parser = [{ return parseParallelOp(parser, result); }]; let printer = [{ return printParallelOp(p, *this); }]; @@ -238,7 +238,7 @@ let builders = [ OpBuilder<(ins "ValueRange":$lowerBound, "ValueRange":$upperBound, "ValueRange":$step, - CArg<"ArrayRef", "{}">:$attributes)>, + CArg<"AttributeRange", "{}">:$attributes)>, OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$lowerBound, "ValueRange":$upperBound, "ValueRange":$step, "ValueRange":$privateVars, "ValueRange":$firstprivateVars, @@ -249,7 +249,7 @@ "IntegerAttr":$ordered_val, "StringAttr":$order_val, "UnitAttr":$inclusive, CArg<"bool", "true">:$buildBody)>, OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$operands, - CArg<"ArrayRef", "{}">:$attributes)> + CArg<"AttributeRange", "{}">:$attributes)> ]; let regions = (region AnyRegion:$region); diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td +++ 
b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td @@ -354,7 +354,7 @@ let builders = [ OpBuilder<(ins "Value":$ptr, "Value":$value, - CArg<"ArrayRef", "{}">:$namedAttrs), + CArg<"AttributeRange", "{}">:$namedAttrs), [{ $_state.addOperands(ptr); $_state.addOperands(value); diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td @@ -300,7 +300,7 @@ let builders = [ OpBuilder<(ins "StringRef":$name, "FunctionType":$type, CArg<"spirv::FunctionControl", "spirv::FunctionControl::None">:$control, - CArg<"ArrayRef", "{}">:$attrs)>]; + CArg<"AttributeRange", "{}">:$attrs)>]; let hasOpcode = 0; @@ -398,7 +398,7 @@ if (initializer) $_state.addAttribute(initializerAttrName($_state.name), initializer); }]>, - OpBuilder<(ins "TypeAttr":$type, "ArrayRef":$namedAttrs), + OpBuilder<(ins "TypeAttr":$type, "AttributeRange":$namedAttrs), [{ $_state.addAttribute("type", type); $_state.addAttributes(namedAttrs); diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -217,7 +217,7 @@ // TODO: This should really be automatic. Figure out how to not need this defined. static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands, - ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions, + ::mlir::AttributeRange attributes, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes) { inferredReturnTypes.push_back(::mlir::IntegerType::get(context, /*width=*/1)); @@ -295,7 +295,7 @@ // TODO: This should really be automatic. Figure out how to not need this defined. 
static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands, - ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions, + ::mlir::AttributeRange attributes, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes) { inferredReturnTypes.push_back(::mlir::IntegerType::get(context, /*width=*/1)); @@ -925,7 +925,7 @@ // TODO: This should really be automatic. Figure out how to not need this defined. static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands, - ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions, + ::mlir::AttributeRange attributes, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes) { inferredReturnTypes.push_back(::mlir::shape::WitnessType::get(context)); return success(); @@ -959,7 +959,7 @@ // TODO: This should really be automatic. Figure out how to not need this defined. static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands, - ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions, + ::mlir::AttributeRange attributes, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes) { inferredReturnTypes.push_back(::mlir::shape::WitnessType::get(context)); return success(); diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -225,22 +225,22 @@ // inferred result type. 
OpBuilder<(ins "Value":$source, "ArrayRef":$offsets, "ArrayRef":$sizes, "ArrayRef":$strides, - CArg<"ArrayRef", "{}">:$attrs)>, + CArg<"AttributeRange", "{}">:$attrs)>, // Build an ExtractSliceOp with mixed static and dynamic entries and custom // result type. If the type passed is nullptr, it is inferred. OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source, "ArrayRef":$offsets, "ArrayRef":$sizes, "ArrayRef":$strides, - CArg<"ArrayRef", "{}">:$attrs)>, + CArg<"AttributeRange", "{}">:$attrs)>, // Build an ExtractSliceOp with dynamic entries and custom result type. If // the type passed is nullptr, it is inferred. OpBuilder<(ins "Value":$source, "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides, - CArg<"ArrayRef", "{}">:$attrs)>, + CArg<"AttributeRange", "{}">:$attrs)>, // Build an ExtractSliceOp with dynamic entries and inferred result type. OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source, "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides, - CArg<"ArrayRef", "{}">:$attrs)> + CArg<"AttributeRange", "{}">:$attrs)> ]; let extraClassDeclaration = extraBaseClassDeclaration # [{ @@ -500,11 +500,11 @@ OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef":$offsets, "ArrayRef":$sizes, "ArrayRef":$strides, - CArg<"ArrayRef", "{}">:$attrs)>, + CArg<"AttributeRange", "{}">:$attrs)>, // Build a InsertSliceOp with dynamic entries. 
OpBuilder<(ins "Value":$source, "Value":$dest, "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides, - CArg<"ArrayRef", "{}">:$attrs)> + CArg<"AttributeRange", "{}">:$attrs)> ]; let extraClassDeclaration = extraBaseClassDeclaration # [{ diff --git a/mlir/include/mlir/IR/AttributeRange.h b/mlir/include/mlir/IR/AttributeRange.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/AttributeRange.h @@ -0,0 +1,179 @@ +#ifndef MLIR_IR_ATTRIBUTERANGE_H_ +#define MLIR_IR_ATTRIBUTERANGE_H_ + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Identifier.h" +#include "llvm/ADT/STLExtras.h" + +namespace mlir { +class DictionaryAttr; +class NamedAttrList; +namespace detail { +class OpAttributeList; +} // end namespace detail + +namespace impl { +template +std::pair +lookupAttributeSorted(IteratorT first, IteratorT last, StringRef name) { + auto length = std::distance(first, last); + using diff_t = decltype(length); + + while (length > 0) { + diff_t half = length / 2; + IteratorT mid = first + half; + auto c = mid->first.strref().compare(name); + if (c < 0) { + first = mid; + ++first; + length = length - half - 1; + } else if (c > 0) { + length = half; + } else { + return {mid, true}; + } + } + return {first, false}; +} + +template +std::pair +lookupAttributeSorted(IteratorT first, IteratorT last, Identifier name) { + auto it = std::find_if(first, last, [name](const NamedAttribute &attr) { + return attr.first == name; + }); + return {it, it != last}; +} + +namespace detail { +const NamedAttribute *lookupAttributeBig(ArrayRef attrs, + Identifier name); +inline const NamedAttribute * +lookupAttributeSmall(ArrayRef attrs, Identifier name) { + for (auto *it = attrs.begin(), *e = attrs.end(); it != e; ++it) { + if (it->first == name) + return it; + } + return nullptr; +} + +constexpr unsigned kSmallAttrList = 16; +} // end namespace detail + +inline const NamedAttribute * +lookupAttributeSpecial(ArrayRef attrs, Identifier name) { + return attrs.size() >= 
detail::kSmallAttrList + ? detail::lookupAttributeBig(attrs, name) + : detail::lookupAttributeSmall(attrs, name); +} + +inline const NamedAttribute * +lookupAttributeSubrange(ArrayRef attrs, Identifier name, + unsigned left, unsigned right) { + for (auto *it = attrs.begin() + left, *e = attrs.end() - right; it != e; + ++it) { + if (it->first == name) + return it; + } + return nullptr; +} + +template +std::pair +lookupAttributeUnsorted(IteratorT first, IteratorT last, StringRef name) { + auto it = std::find_if(first, last, [name](const NamedAttribute &attr) { + return attr.first == name; + }); + return {it, it != last}; +} + +template +std::pair +lookupAttributeUnsorted(IteratorT first, IteratorT last, Identifier name) { + return lookupAttributeSorted(first, last, name); +} +} // end namespace impl + +class AttributeRange : public ArrayRef { +public: + using Base = ArrayRef; + + AttributeRange(ArrayRef value = {}, bool sorted = false) + : Base(value), sorted(sorted) {} + + /// Conversion to AttributeRange from other ranges and containers of + /// NamedAttribute. + AttributeRange(const SmallVectorImpl &attrs, + bool sorted = false) + : Base(attrs), sorted(sorted) {} + AttributeRange(const std::vector &attrs, bool sorted = false) + : Base(attrs), sorted(sorted) {} + AttributeRange(const llvm::NoneType) : AttributeRange() {} + AttributeRange(std::initializer_list attrs, + bool sorted = false) + : Base(attrs), sorted(sorted) {} + AttributeRange(const NamedAttribute &attr) : Base(attr), sorted(false) {} + + /// Conversion to sorted AttributeRange from ranges and containers that are + /// known to be sorted. + AttributeRange(const NamedAttrList &attrs); + AttributeRange(DictionaryAttr attrs); + AttributeRange(const detail::OpAttributeList &attrs); + + bool isSorted() const { return sorted; } + + SmallVector toVector() const { + return llvm::to_vector::value>(*this); + } + +private: + std::pair lookupAttribute(StringRef name) const { + return isSorted() ? 
impl::lookupAttributeSorted(begin(), end(), name) + : impl::lookupAttributeUnsorted(begin(), end(), name); + } + const NamedAttribute *lookupAttrId(Identifier name) const { + return isSorted() && size() >= impl::detail::kSmallAttrList + ? impl::detail::lookupAttributeBig(*this, name) + : impl::detail::lookupAttributeSmall(*this, name); + } + +public: + /// Read-only dictionary-like operations. + Attribute get(StringRef name) const { + auto result = lookupAttribute(name); + return result.second ? result.first->second : Attribute(); + } + Attribute get(Identifier name) const { + auto *ptr = lookupAttrId(name); + return ptr ? ptr->second : Attribute(); + } + template + AttributeT getAs(NameT name) const { + return get(name).template dyn_cast_or_null(); + } + Optional getNamed(StringRef name) const { + auto result = lookupAttribute(name); + return result.second ? *result.first : Optional(); + } + Optional getNamed(Identifier name) const { + auto *ptr = lookupAttrId(name); + return ptr ? *ptr : Optional(); + } + + bool operator==(const AttributeRange &other) const { return equals(other); } + +private: + /// An attribute range created from a sorted NamedAttrList, DictionaryAttr, + /// or OpAttributeList is known to be sorted. Any operation that involves a + /// name lookup can be faster. 
+ bool sorted = false; +}; + +inline llvm::hash_code hash_value(AttributeRange attrs) { + return llvm::hash_combine_range(attrs.begin(), attrs.end()); +} + +} // end namespace mlir + +#endif // MLIR_IR_ATTRIBUTERANGE_H_ diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -91,7 +91,7 @@ UnitAttr getUnitAttr(); BoolAttr getBoolAttr(bool value); - DictionaryAttr getDictionaryAttr(ArrayRef value); + DictionaryAttr getDictionaryAttr(AttributeRange value); IntegerAttr getIntegerAttr(Type type, int64_t value); IntegerAttr getIntegerAttr(Type type, const APInt &value); FloatAttr getFloatAttr(Type type, double value); diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -9,6 +9,7 @@ #ifndef MLIR_IR_BUILTINATTRIBUTES_H #define MLIR_IR_BUILTINATTRIBUTES_H +#include "mlir/IR/AttributeRange.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/SubElementInterfaces.h" #include "llvm/ADT/APFloat.h" diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -344,6 +344,12 @@ // DictionaryAttr //===----------------------------------------------------------------------===// +def AttrRangeParameter : AttrOrTypeParameter<"::mlir::AttributeRange", ""> { + let allocator = [{ + $_dst = AttributeRange($_allocator.copyInto($_self), true); + }]; +} + def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary", [ DeclareAttrInterfaceMethods ]> { @@ -367,17 +373,20 @@ {int_attr = 10, "string attr name" = "string attribute"} ``` }]; - let parameters = (ins ArrayRefParameter<"NamedAttribute", "">:$value); + // TODO: DictionaryAttr should exist as an ArrayRef and only + // be converted to AttributeRange as 
needed, but this cannot be done with ODS + // attributes. + let parameters = (ins AttrRangeParameter:$value); let builders = [ - AttrBuilder<(ins CArg<"ArrayRef", "llvm::None">:$value)> + AttrBuilder<(ins CArg<"AttributeRange", "llvm::None">:$value)> ]; let extraClassDeclaration = [{ - using ValueType = ArrayRef; + using ValueType = AttributeRange; /// Construct a dictionary with an array of values that is known to already /// be sorted by name and uniqued. static DictionaryAttr getWithSorted(MLIRContext *context, - ArrayRef value); + AttributeRange value); /// Return the specified attribute if present, null otherwise. Attribute get(StringRef name) const; @@ -388,7 +397,7 @@ Optional getNamed(Identifier name) const; /// Support range iteration. - using iterator = llvm::ArrayRef::iterator; + using iterator = AttributeRange::iterator; iterator begin() const; iterator end() const; bool empty() const { return size() == 0; } @@ -397,7 +406,7 @@ /// Sorts the NamedAttributes in the array ordered by name as expected by /// getWithSorted and returns whether the values were sorted. /// Requires: uniquely named attributes. 
- static bool sort(ArrayRef values, + static bool sort(AttributeRange values, SmallVectorImpl &storage); /// Sorts the NamedAttributes in the array ordered by name as expected by diff --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td --- a/mlir/include/mlir/IR/BuiltinOps.td +++ b/mlir/include/mlir/IR/BuiltinOps.td @@ -81,16 +81,16 @@ let builders = [OpBuilder<(ins "StringRef":$name, "FunctionType":$type, - CArg<"ArrayRef", "{}">:$attrs, + CArg<"AttributeRange", "{}">:$attrs, CArg<"ArrayRef", "{}">:$argAttrs) >]; let extraClassDeclaration = [{ static FuncOp create(Location location, StringRef name, FunctionType type, - ArrayRef attrs = {}); + AttributeRange attrs = {}); static FuncOp create(Location location, StringRef name, FunctionType type, Operation::dialect_attr_range attrs); static FuncOp create(Location location, StringRef name, FunctionType type, - ArrayRef attrs, + AttributeRange attrs, ArrayRef argAttrs); /// Create a deep copy of this function and all of its blocks, remapping any diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h --- a/mlir/include/mlir/IR/FunctionSupport.h +++ b/mlir/include/mlir/IR/FunctionSupport.h @@ -57,13 +57,13 @@ void setAllResultAttrDicts(Operation *op, ArrayRef attrs); /// Return all of the attributes for the argument at 'index'. -inline ArrayRef getArgAttrs(Operation *op, unsigned index) { +inline AttributeRange getArgAttrs(Operation *op, unsigned index) { auto argDict = getArgAttrDict(op, index); return argDict ? argDict.getValue() : llvm::None; } /// Return all of the attributes for the result at 'index'. -inline ArrayRef getResultAttrs(Operation *op, unsigned index) { +inline AttributeRange getResultAttrs(Operation *op, unsigned index) { auto resultDict = getResultAttrDict(op, index); return resultDict ? resultDict.getValue() : llvm::None; } @@ -374,7 +374,7 @@ /// exist if they are non-empty. /// Return all of the attributes for the argument at 'index'. 
- ArrayRef getArgAttrs(unsigned index) { + AttributeRange getArgAttrs(unsigned index) { return function_like_impl::getArgAttrs(this->getOperation(), index); } @@ -415,7 +415,7 @@ } /// Set the attributes held by the argument at 'index'. - void setArgAttrs(unsigned index, ArrayRef attributes); + void setArgAttrs(unsigned index, AttributeRange attributes); /// Set the attributes held by the argument at 'index'. `attributes` may be /// null, in which case any existing argument attributes are removed. @@ -463,7 +463,7 @@ /// exist if they are non-empty. /// Return all of the attributes for the result at 'index'. - ArrayRef getResultAttrs(unsigned index) { + AttributeRange getResultAttrs(unsigned index) { return function_like_impl::getResultAttrs(this->getOperation(), index); } @@ -504,7 +504,7 @@ } /// Set the attributes held by the result at 'index'. - void setResultAttrs(unsigned index, ArrayRef attributes); + void setResultAttrs(unsigned index, AttributeRange attributes); /// Set the attributes held by the result at 'index'. `attributes` may be /// null, in which case any existing argument attributes are removed. @@ -697,8 +697,8 @@ /// Set the attributes held by the argument at 'index'. template -void FunctionLike::setArgAttrs( - unsigned index, ArrayRef attributes) { +void FunctionLike::setArgAttrs(unsigned index, + AttributeRange attributes) { assert(index < getNumArguments() && "invalid argument number"); Operation *op = this->getOperation(); return function_like_impl::detail::setArgResAttrDict( @@ -748,8 +748,8 @@ /// Set the attributes held by the result at 'index'. 
template -void FunctionLike::setResultAttrs( - unsigned index, ArrayRef attributes) { +void FunctionLike::setResultAttrs(unsigned index, + AttributeRange attributes) { assert(index < getNumResults() && "invalid result number"); Operation *op = this->getOperation(); return function_like_impl::detail::setArgResAttrDict( diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2297,7 +2297,7 @@ // static void build(OpBuilder &, OperationState &odsState, // TypeRange resultTypes, // ValueRange operands, - // ArrayRef attributes); + // AttributeRange attributes); // ``` list builders = ?; diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -211,7 +211,7 @@ /// You may pass omitType=true to not print a type, and pass an empty /// attribute list if you don't care for attributes. virtual void printRegionArgument(BlockArgument arg, - ArrayRef argAttrs = {}, + AttributeRange argAttrs = {}, bool omitType = false) = 0; /// Print implementations for various things an operation contains. @@ -247,13 +247,13 @@ /// dictionary with their values. elidedAttrs allows the client to ignore /// specific well known attributes, commonly used if the attribute value is /// printed some other way (like as a fixed operand). - virtual void printOptionalAttrDict(ArrayRef attrs, + virtual void printOptionalAttrDict(AttributeRange attrs, ArrayRef elidedAttrs = {}) = 0; /// If the specified operation has attributes, print out an attribute /// dictionary prefixed with 'attributes'. virtual void - printOptionalAttrDictWithKeyword(ArrayRef attrs, + printOptionalAttrDictWithKeyword(AttributeRange attrs, ArrayRef elidedAttrs = {}) = 0; /// Print the entire operation with the default generic assembly form. 
@@ -589,14 +589,14 @@ /// unlike `OpBuilder::getType`, this method does not implicitly insert a /// context parameter. template - T getChecked(llvm::SMLoc loc, ParamsT &&... params) { + T getChecked(llvm::SMLoc loc, ParamsT &&...params) { return T::getChecked([&] { return emitError(loc); }, std::forward(params)...); } /// A variant of `getChecked` that uses the result of `getNameLoc` to emit /// errors. template - T getChecked(ParamsT &&... params) { + T getChecked(ParamsT &&...params) { return T::getChecked([&] { return emitError(getNameLoc()); }, std::forward(params)...); } diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -28,19 +28,12 @@ class alignas(8) Operation final : public llvm::ilist_node_with_parent, private llvm::TrailingObjects { + NamedAttribute, detail::OperandStorage> { public: /// Create a new Operation with the specific fields. static Operation *create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, - ArrayRef attributes, - BlockRange successors, unsigned numRegions); - - /// Overload of create that takes an existing DictionaryAttr to avoid - /// unnecessarily uniquing a list of attributes. - static Operation *create(Location location, OperationName name, - TypeRange resultTypes, ValueRange operands, - DictionaryAttr attributes, BlockRange successors, + AttributeRange attributes, BlockRange successors, unsigned numRegions); /// Create a new Operation from the fields stored in `state`. @@ -49,7 +42,7 @@ /// Create a new Operation with the specific fields. static Operation *create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, - DictionaryAttr attributes, + AttributeRange attributes, BlockRange successors = {}, RegionRange regions = {}); @@ -302,24 +295,29 @@ // the lifetime of an operation. /// Return all of the attributes on this operation. 
- ArrayRef getAttrs() { return attrs.getValue(); } + AttributeRange getAttrs() { return attrs; } + AttributeRange::iterator attr_begin() { return getAttrs().begin(); } + AttributeRange::iterator attr_end() { return getAttrs().end(); } /// Return all of the attributes on this operation as a DictionaryAttr. - DictionaryAttr getAttrDictionary() { return attrs; } + DictionaryAttr getAttrDictionary() { + return attrs.convertToDictionary(getContext()); + } /// Set the attribute dictionary on this operation. - void setAttrs(DictionaryAttr newAttrs) { - assert(newAttrs && "expected valid attribute dictionary"); - attrs = newAttrs; - } - void setAttrs(ArrayRef newAttrs) { - setAttrs(DictionaryAttr::get(getContext(), newAttrs)); - } + void setAttrs(AttributeRange newAttrs) { attrs.assign(newAttrs); } /// Return the specified attribute if present, null otherwise. Attribute getAttr(Identifier name) { return attrs.get(name); } Attribute getAttr(StringRef name) { return attrs.get(name); } + Optional getNamedAttr(Identifier name) { + return attrs.getNamed(name); + } + Optional getNamedAttr(StringRef name) { + return attrs.getNamed(name); + } + template AttrClass getAttrOfType(Identifier name) { return getAttr(name).dyn_cast_or_null(); } @@ -329,8 +327,8 @@ /// Return true if the operation has an attribute with the provided name, /// false otherwise. - bool hasAttr(Identifier name) { return static_cast(getAttr(name)); } - bool hasAttr(StringRef name) { return static_cast(getAttr(name)); } + bool hasAttr(Identifier name) { return attrs.has(name); } + bool hasAttr(StringRef name) { return attrs.has(name); } template bool hasAttrOfType(NameT &&name) { return static_cast( @@ -339,41 +337,27 @@ /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. 
- void setAttr(Identifier name, Attribute value) { - NamedAttrList attributes(attrs); - if (attributes.set(name, value) != value) - attrs = attributes.getDictionary(getContext()); - } - void setAttr(StringRef name, Attribute value) { - setAttr(Identifier::get(name, getContext()), value); - } + void setAttr(Identifier name, Attribute value) { attrs.set(name, value); } + void setAttr(StringRef name, Attribute value) { attrs.set(name, value); } /// Remove the attribute with the specified name if it exists. Return the /// attribute that was erased, or nullptr if there was no attribute with such /// name. - Attribute removeAttr(Identifier name) { - NamedAttrList attributes(attrs); - Attribute removedAttr = attributes.erase(name); - if (removedAttr) - attrs = attributes.getDictionary(getContext()); - return removedAttr; - } - Attribute removeAttr(StringRef name) { - return removeAttr(Identifier::get(name, getContext())); - } + Attribute removeAttr(Identifier name) { return attrs.erase(name); } + Attribute removeAttr(StringRef name) { return attrs.erase(name); } /// A utility iterator that filters out non-dialect attributes. class dialect_attr_iterator - : public llvm::filter_iterator::iterator, + : public llvm::filter_iterator { static bool filter(NamedAttribute attr) { // Dialect attributes are prefixed by the dialect name, like operations. return attr.first.strref().count('.'); } - explicit dialect_attr_iterator(ArrayRef::iterator it, - ArrayRef::iterator end) - : llvm::filter_iterator::iterator, + explicit dialect_attr_iterator(AttributeRange::iterator it, + AttributeRange::iterator end) + : llvm::filter_iterator(it, end, &filter) {} // Allow access to the constructor. @@ -404,9 +388,12 @@ for (auto attr : getAttrs()) if (!attr.first.strref().contains('.')) attrs.push_back(attr); - setAttrs(attrs.getDictionary(getContext())); + setAttrs(attrs); } + /// Get a reference to the underlying attribute storage. 
+ detail::OpAttributeList &getOpAttributes() { return attrs; } + //===--------------------------------------------------------------------===// // Blocks //===--------------------------------------------------------------------===// @@ -600,7 +587,8 @@ private: Operation(Location location, OperationName name, unsigned numResults, unsigned numSuccessors, unsigned numRegions, - DictionaryAttr attributes, bool hasOperandStorage); + unsigned numInlineAttributes, AttributeRange attributes, + bool hasOperandStorage); // Operations are deleted through the destroy() member because they are // allocated with malloc. @@ -652,6 +640,11 @@ return getOutOfLineOpResult(resultNumber - maxInlineResults); } + /// Get the inline storage allocated for attributes. + MutableArrayRef getInlineAttrStorage() { + return {getTrailingObjects(), numInlineAttrs}; + } + /// Provide a 'getParent' method for ilist_node_with_parent methods. /// We mark it as a const function because ilist_node_with_parent specifically /// requires a 'getParent() const' method. Once ilist_node removes this @@ -672,7 +665,9 @@ const unsigned numResults; const unsigned numSuccs; - const unsigned numRegions : 31; + const unsigned numRegions : 25; + static constexpr unsigned kMaxInlineAttrs = 63; + const unsigned numInlineAttrs : 6; /// This bit signals whether this operation has an operand storage or not. The /// operand storage may be elided for operations that are known to never have @@ -683,7 +678,7 @@ OperationName name; /// This holds general named attributes for the operation. - DictionaryAttr attrs; + detail::OpAttributeList attrs; // allow ilist_traits access to 'block' field. friend struct llvm::ilist_traits; @@ -698,12 +693,15 @@ friend class llvm::ilist_node_with_parent; // This stuff is used by the TrailingObjects template. 
- friend llvm::TrailingObjects; size_t numTrailingObjects(OverloadToken) const { return numSuccs; } size_t numTrailingObjects(OverloadToken) const { return numRegions; } + size_t numTrailingObjects(OverloadToken) const { + return numInlineAttrs; + } }; inline raw_ostream &operator<<(raw_ostream &os, const Operation &op) { diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -14,6 +14,7 @@ #ifndef MLIR_IR_OPERATION_SUPPORT_H #define MLIR_IR_OPERATION_SUPPORT_H +#include "mlir/IR/AttributeRange.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BlockSupport.h" #include "mlir/IR/Identifier.h" @@ -259,9 +260,13 @@ using size_type = size_t; NamedAttrList() : dictionarySorted({}, true) {} - NamedAttrList(ArrayRef attributes); + NamedAttrList(AttributeRange attributes); NamedAttrList(DictionaryAttr attributes); NamedAttrList(const_iterator in_start, const_iterator in_end); + template + NamedAttrList(IteratorT in_start, IteratorT in_end) { + assign(in_start, in_end); + } bool operator!=(const NamedAttrList &other) const { return !(*this == other); @@ -299,11 +304,15 @@ /// Replaces the attributes with new list of attributes. void assign(const_iterator in_start, const_iterator in_end); - /// Replaces the attributes with new list of attributes. - void assign(ArrayRef range) { - assign(range.begin(), range.end()); + template + void assign(IteratorT in_start, IteratorT in_end) { + dictionarySorted.setPointerAndInt(nullptr, false); + attrs.assign(in_start, in_end); } + /// Replaces the attributes with new list of attributes. + void assign(AttributeRange range) { assign(range.begin(), range.end()); } + bool empty() const { return attrs.empty(); } void reserve(size_type N) { attrs.reserve(N); } @@ -323,7 +332,7 @@ DictionaryAttr getDictionary(MLIRContext *context) const; /// Return all of the attributes on this operation. 
- ArrayRef getAttrs() const; + AttributeRange getAttrs() const; /// Return the specified attribute if present, null otherwise. Attribute get(Identifier name) const; @@ -352,10 +361,10 @@ NamedAttrList &operator=(const SmallVectorImpl &rhs); operator ArrayRef() const; -private: /// Return whether the attributes are sorted. bool isSorted() const { return dictionarySorted.getInt(); } +private: /// Erase the attribute at the given iterator position. Attribute eraseImpl(SmallVectorImpl::iterator it); @@ -464,7 +473,7 @@ OperationState(Location location, OperationName name); OperationState(Location location, StringRef name, ValueRange operands, - TypeRange types, ArrayRef attributes, + TypeRange types, AttributeRange attributes, BlockRange successors = {}, MutableArrayRef> regions = {}); @@ -490,7 +499,7 @@ } /// Add an array of named attributes. - void addAttributes(ArrayRef newAttributes) { + void addAttributes(AttributeRange newAttributes) { attributes.append(newAttributes); } @@ -515,6 +524,222 @@ MLIRContext *getContext() const { return location->getContext(); } }; +//===----------------------------------------------------------------------===// +// OpAttributeList +//===----------------------------------------------------------------------===// + +namespace detail { +class AttributeAllocator : public std::allocator { +public: + using size_type = unsigned; + using difference_type = int; +}; + +class OpAttributeStorage { +public: + using size_type = unsigned; + using difference_type = int; + using iterator = NamedAttribute *; + using const_iterator = const NamedAttribute *; + +private: + /// The provided storage is inline with the operation, and for "small" lists, + /// will contain exactly enough space to contain the attributes [first, last). + /// For "large" lists, no inline storage is provided. 
+ OpAttributeStorage(MutableArrayRef initialStorage, + const_iterator first, const_iterator last) + : ptr(initialStorage.data()), size(0), capacity(initialStorage.size()), + isHeap(false) { + assign(first, last); + } + /// Free the memory if it came from the heap. + ~OpAttributeStorage() { + if (isHeap) + free(ptr); + } + + operator ArrayRef() const { return {ptr, size}; } + + iterator begin() { return ptr; } + iterator end() { return begin() + size; } + const_iterator begin() const { return ptr; } + const_iterator end() const { return begin() + size; } + + void erase(iterator it) { + size_type newSize = size - 1; + for (auto *e = end() - 1; it != e; ++it) + *it = *(it + 1); + size = newSize; + } + + void insert(iterator it, NamedAttribute value) { + size_type newSize = size + 1; + if (newSize > capacity) { + capacity = 2u * (capacity + 1u); + size_type memSize = capacity * sizeof(NamedAttribute); + size_type index = it - ptr; + if (isHeap) { + ptr = reinterpret_cast(realloc(ptr, memSize)); + it = ptr + index; + for (auto *e = end(); e != it; --e) + *e = *(e - 1); + *it = value; + } else { + auto *newPtr = reinterpret_cast(malloc(memSize)); + pod_copy(ptr, newPtr, index); + *(it = newPtr + index) = value; + pod_copy(ptr + index, it + 1, size - index); + ptr = newPtr; + } + } else { + for (auto *e = end(); e != it; --e) + *e = *(e - 1); + *it = value; + } + size = newSize; + } + + void assign(const_iterator first, const_iterator last) { + size_type newSize = last - first; + if (newSize > capacity) { + capacity = std::max(newSize, 2u * (capacity + 1u)); + size_t memSize = capacity * sizeof(NamedAttribute); + ptr = reinterpret_cast(isHeap ? realloc(ptr, memSize) + : malloc(memSize)); + } + pod_copy(first, ptr, newSize); + size = newSize; + } + + static void pod_copy(const_iterator first, iterator dest, size_type num) { + memcpy(dest, first, num * sizeof(NamedAttribute)); + } + + /// Store whether the pointer is from dynamic storage. 
+ NamedAttribute *ptr; + size_type size; + size_type capacity; + bool isHeap; + + friend class OpAttributeList; +}; + +class OpAttributeList { + using Storage = std::vector; + // using Storage = SmallVector; + // using Storage = OpAttributeStorage; + // using Storage = std::vector; + +public: + using iterator = Storage::iterator; + using const_iterator = Storage::const_iterator; + + OpAttributeList(MutableArrayRef initialStorage, + AttributeRange attributes); + + Attribute get(Identifier name) const { + auto *ptr = impl::lookupAttributeSpecial(attrs, name); + return ptr ? ptr->second : Attribute(); + } + Attribute get(StringRef name) const { + auto it = lookupAttribute(name); + return it.second ? it.first->second : Attribute(); + } + Attribute get(Identifier name, unsigned left, unsigned right) const { + auto *ptr = impl::lookupAttributeSubrange(attrs, name, left, right); + // if (LLVM_UNLIKELY(!ptr)) + // ptr = impl::lookupAttributeSpecial(attrs, name); + return ptr ? ptr->second : Attribute(); + } + + Optional getNamed(Identifier name) const { + auto *ptr = impl::lookupAttributeSpecial(attrs, name); + return ptr ? *ptr : Optional(); + } + Optional getNamed(StringRef name) const { + auto it = lookupAttribute(name); + return it.second ? *it.first : Optional(); + } + Optional getNamed(Identifier name, unsigned left, + unsigned right) const { + auto *ptr = impl::lookupAttributeSubrange(attrs, name, left, right); + return ptr ? 
*ptr : Optional(); + } + + bool has(Identifier name) const { + return impl::lookupAttributeSpecial(attrs, name); + } + bool has(StringRef name) const { return lookupAttribute(name).second; } + bool has(Identifier name, unsigned left, unsigned right) const { + return impl::lookupAttributeSubrange(attrs, name, left, right); + } + + template + Attribute erase(NameT name, unsigned left = 0, unsigned right = 0) { + auto it = lookupAttribute(name, left, right); + if (!it.second) + return Attribute(); + auto value = it.first->second; + attrs.erase(it.first); + return value; + } + + void set(StringRef name, Attribute value); + void set(Identifier name, Attribute value, unsigned left = 0, + unsigned right = 0); + + /// Add an attribute when it is known that the attribute is more likely to not + /// already be in the list. This method goes straight to string lookup, which + /// is faster than set(Identifier, Attribute) when adding a new attribute. + void addOrSet(Identifier name, Attribute value); + /// Set an attribute when it is known that the attribute is more like to + /// already be in the list. This method skips the string lookup, which *may* + /// be faster than set(StringRef, Attribute) when overwriting an attribute. + void setOrAdd(StringRef name, Attribute value); + + void definitelySet(Identifier name, Attribute value, unsigned left, + unsigned right); + template < + typename MaterializeValueFn, + typename AttributeT = decltype(std::declval()())> + AttributeT getOrSet(Identifier name, MaterializeValueFn &&materialize, + unsigned left, unsigned right) { + auto it = lookupAttribute(name.strref(), left, right); + if (it.second) + return it.first->second.cast(); + AttributeT value = materialize(); + // attrs.insert(it.first, {name, value}); + return value; + } + + void assign(AttributeRange attributes); + + /// Conversions to other attribute ranges and containers. 
+ operator ArrayRef() const { return attrs; } + DictionaryAttr convertToDictionary(MLIRContext *context) const; + +public: + // Search a subset of the list when it is known that an attribute will have + // `left` attributes to its left and `right` attributes to its right based on + // sortedness. + template + std::pair lookupAttribute(NameT name, unsigned left = 0, + unsigned right = 0) { + return impl::lookupAttributeSorted(attrs.begin() + left, + attrs.end() - right, name); + } + template + std::pair lookupAttribute(NameT name, unsigned left = 0, + unsigned right = 0) const { + return impl::lookupAttributeSorted(attrs.begin() + left, + attrs.end() - right, name); + } + + /// The internal storage of attributes inside an operation. + Storage attrs; +}; +} // end namespace detail + //===----------------------------------------------------------------------===// // OperandStorage //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h @@ -233,11 +233,11 @@ LogicalResult inferReturnTensorTypes( function_ref location, ValueShapeRange operands, - DictionaryAttr attributes, RegionRange regions, + AttributeRange attributes, RegionRange regions, SmallVectorImpl &retComponents)> componentTypeFn, MLIRContext *context, Optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, + AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes); /// Verifies that the inferred result types match the actual result types for @@ -269,7 +269,7 @@ public: static LogicalResult inferReturnTypes(MLIRContext *context, Optional location, - ValueRange operands, DictionaryAttr attributes, + ValueRange operands, AttributeRange attributes, RegionRange regions, SmallVectorImpl 
&inferredReturnTypes) { static_assert( diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td @@ -43,7 +43,7 @@ /*args=*/(ins "::mlir::MLIRContext *":$context, "::llvm::Optional<::mlir::Location>":$location, "::mlir::ValueRange":$operands, - "::mlir::DictionaryAttr":$attributes, + "::mlir::AttributeRange":$attributes, "::mlir::RegionRange":$regions, "::llvm::SmallVectorImpl<::mlir::Type>&":$inferredReturnTypes) >, @@ -96,7 +96,7 @@ /*args=*/(ins "::mlir::MLIRContext*":$context, "::mlir::Optional<::mlir::Location>":$location, "::mlir::ValueShapeRange":$operands, - "::mlir::DictionaryAttr":$attributes, + "::mlir::AttributeRange":$attributes, "::mlir::RegionRange":$regions, "::mlir::SmallVectorImpl<::mlir::ShapedTypeComponents>&": $inferredReturnShapes), diff --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h --- a/mlir/include/mlir/Support/StorageUniquer.h +++ b/mlir/include/mlir/Support/StorageUniquer.h @@ -94,12 +94,22 @@ public: /// Copy the specified array of elements into memory managed by our bump /// pointer allocator. This assumes the elements are all PODs. 
- template ArrayRef copyInto(ArrayRef elements) { - if (elements.empty()) + template ()))>>> + ArrayRef copyInto(RangeT elements) { + return copyInto(std::begin(elements), std::end(elements)); + } + + template ())>>> + ArrayRef copyInto(IteratorT first, IteratorT last) { + if (first == last) return llvm::None; - auto result = allocator.Allocate(elements.size()); - std::uninitialized_copy(elements.begin(), elements.end(), result); - return ArrayRef(result, elements.size()); + auto *result = allocator.Allocate(std::distance(first, last)); + std::uninitialized_copy(first, last, result); + return ArrayRef(result, std::distance(first, last)); } /// Copy the provided string into memory managed by our bump pointer diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -353,7 +353,7 @@ /// Construct a conversion pattern with the given converter, and forward the /// remaining arguments to RewritePattern. template - ConversionPattern(TypeConverter &typeConverter, Args &&... args) + ConversionPattern(TypeConverter &typeConverter, Args &&...args) : RewritePattern(std::forward(args)...), typeConverter(&typeConverter) {} @@ -384,14 +384,14 @@ /// type. 
void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - rewrite(cast(op), OpAdaptor(operands, op->getAttrDictionary()), - rewriter); + rewrite(cast(op), + OpAdaptor(getContext(), operands, op->getAttrs()), rewriter); } LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { return matchAndRewrite(cast(op), - OpAdaptor(operands, op->getAttrDictionary()), + OpAdaptor(getContext(), operands, op->getAttrs()), rewriter); } @@ -408,7 +408,7 @@ ConversionPatternRewriter &rewriter) const { if (failed(match(op))) return failure(); - rewrite(op, OpAdaptor(operands, op->getAttrDictionary()), rewriter); + rewrite(op, OpAdaptor(getContext(), operands, op->getAttrs()), rewriter); return success(); } diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -286,8 +286,8 @@ } if (succeeded(inferInterface->inferReturnTypes( - context, state.location, state.operands, - state.attributes.getDictionary(context), state.regions, state.types))) + context, state.location, state.operands, state.attributes, + state.regions, state.types))) return success(); // Diagnostic emitted by interface. 
diff --git a/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp b/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp --- a/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp +++ b/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp @@ -180,7 +180,7 @@ return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { - OpAdaptor adaptor(operands); + OpAdaptor adaptor(getContext(), operands); return rewriter.create( op.getLoc(), llvm1DVectorTy, convertCmpPredicate(op.getPredicate()), @@ -217,7 +217,7 @@ return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { - OpAdaptor adaptor(operands); + OpAdaptor adaptor(getContext(), operands); return rewriter.create( op.getLoc(), llvm1DVectorTy, convertCmpPredicate(op.getPredicate()), diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -372,7 +372,7 @@ ValueRange(coroSize.getResult())); // Begin a coroutine: @llvm.coro.begin. 
- auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).id(); + auto coroId = CoroBeginOpAdaptor(getContext(), adaptor.getOperands()).id(); rewriter.replaceOpWithNewOp( op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)})); diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -316,16 +316,14 @@ ConversionPatternRewriter &rewriter) const override { Type operandType = dimOp.source().getType(); if (operandType.isa()) { - rewriter.replaceOp( - dimOp, {extractSizeOfUnrankedMemRef( - operandType, dimOp, adaptor.getOperands(), rewriter)}); + rewriter.replaceOp(dimOp, {extractSizeOfUnrankedMemRef( + operandType, dimOp, adaptor, rewriter)}); return success(); } if (operandType.isa()) { - rewriter.replaceOp( - dimOp, {extractSizeOfRankedMemRef(operandType, dimOp, - adaptor.getOperands(), rewriter)}); + rewriter.replaceOp(dimOp, {extractSizeOfRankedMemRef(operandType, dimOp, + adaptor, rewriter)}); return success(); } llvm_unreachable("expected MemRefType or UnrankedMemRefType"); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -228,12 +228,12 @@ if (!dstType) return failure(); rewriter.replaceOpWithNewOp( - loadOp, dstType, spirv::LoadOpAdaptor(operands).ptr(), alignment, - isVolatile, isNonTemporal); + loadOp, dstType, spirv::LoadOpAdaptor(op->getContext(), operands).ptr(), + alignment, isVolatile, isNonTemporal); return success(); } auto storeOp = cast(op); - spirv::StoreOpAdaptor adaptor(operands); + spirv::StoreOpAdaptor adaptor(op->getContext(), operands); rewriter.replaceOpWithNewOp(storeOp, adaptor.value(), adaptor.ptr(), alignment, isVolatile, isNonTemporal); diff --git 
a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -49,8 +49,7 @@ /// Only retain those attributes that are not constructed by /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument /// attributes. -static void filterFuncAttributes(ArrayRef attrs, - bool filterArgAttrs, +static void filterFuncAttributes(AttributeRange attrs, bool filterArgAttrs, SmallVectorImpl &result) { for (const auto &attr : attrs) { if (attr.first == SymbolTable::getSymbolAttrName() || diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp --- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp +++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp @@ -44,7 +44,8 @@ LLVMTypeConverter &typeConverter, Location loc, TransferWriteOp xferOp, Type &vecTy, Value &dwordConfig, Value &vindex, Value &offsetSizeInBytes, Value &glc, Value &slc) { - auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary()); + auto adaptor = TransferWriteOpAdaptor(rewriter.getContext(), operands, + xferOp->getAttrs()); rewriter.replaceOpWithNewOp(xferOp, adaptor.vector(), dwordConfig, vindex, offsetSizeInBytes, glc, slc); diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -280,7 +280,7 @@ //===----------------------------------------------------------------------===// void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand, - ArrayRef attrs) { + AttributeRange attrs) { result.addOperands({operand}); result.attributes.append(attrs.begin(), attrs.end()); diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp --- 
a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp @@ -253,7 +253,7 @@ auto outputTypes = execute.getResultTypes(); auto funcType = FunctionType::get(ctx, inputTypes, outputTypes); - auto funcAttrs = ArrayRef(); + auto funcAttrs = AttributeRange(); // TODO: Derive outlined function name from the parent FuncOp (support // multiple nested async.execute operations). diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -714,8 +714,7 @@ void GPUFuncOp::build(OpBuilder &builder, OperationState &result, StringRef name, FunctionType type, TypeRange workgroupAttributions, - TypeRange privateAttributions, - ArrayRef attrs) { + TypeRange privateAttributions, AttributeRange attrs) { result.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp --- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp @@ -106,7 +106,7 @@ copy(op->getResultTypes(), std::back_inserter(resultTypes)); resultTypes.push_back(tokenType); auto *newOp = Operation::create(op->getLoc(), op->getName(), resultTypes, - op->getOperands(), op->getAttrDictionary(), + op->getOperands(), op->getAttrs(), op->getSuccessors(), op->getNumRegions()); // Clone regions into new op. 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -48,7 +48,7 @@ #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc" -static auto processFMFAttr(ArrayRef attrs) { +static auto processFMFAttr(AttributeRange attrs) { SmallVector filteredAttrs( llvm::make_filter_range(attrs, [&](NamedAttribute attr) { if (attr.first == "fastmathFlags") { @@ -66,8 +66,8 @@ } static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op, - DictionaryAttr attrs) { - printer.printOptionalAttrDict(processFMFAttr(attrs.getValue())); + AttributeRange attrs) { + printer.printOptionalAttrDict(processFMFAttr(attrs)); } /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and @@ -971,7 +971,7 @@ // wrapped LLVM i32 type. void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result, Value vector, Value position, - ArrayRef attrs) { + AttributeRange attrs) { auto vectorType = vector.getType(); auto llvmType = LLVM::getVectorElementType(vectorType); build(b, result, llvmType, vector, position); @@ -1400,7 +1400,7 @@ void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type, bool isConstant, Linkage linkage, StringRef name, Attribute value, uint64_t alignment, unsigned addrSpace, - bool dsoLocal, ArrayRef attrs) { + bool dsoLocal, AttributeRange attrs) { result.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); result.addAttribute("global_type", TypeAttr::get(type)); @@ -1689,7 +1689,7 @@ // wrapped LLVM i32 type. 
void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result, Value v1, Value v2, ArrayAttr mask, - ArrayRef attrs) { + AttributeRange attrs) { auto containerType = v1.getType(); auto vType = LLVM::getFixedVectorType( LLVM::getVectorElementType(containerType), mask.size()); @@ -1750,7 +1750,7 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result, StringRef name, Type type, LLVM::Linkage linkage, - bool dsoLocal, ArrayRef attrs, + bool dsoLocal, AttributeRange attrs, ArrayRef argAttrs) { result.addRegion(); result.addAttribute(SymbolTable::getSymbolAttrName(), diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -372,8 +372,7 @@ void CopyOp::build(OpBuilder &builder, OperationState &result, Value input, Value output, AffineMap inputPermutation, - AffineMap outputPermutation, - ArrayRef namedAttrs) { + AffineMap outputPermutation, AttributeRange namedAttrs) { result.addOperands({input, output}); result.addAttributes(namedAttrs); if (inputPermutation) @@ -518,7 +517,7 @@ ValueRange inputs, ValueRange outputs, ArrayRef indexingMaps, ArrayRef iteratorTypes, StringRef doc, StringRef libraryCall, function_ref bodyBuild, - ArrayRef attributes) { + AttributeRange attributes) { build(builder, result, resultTensorTypes, inputs, outputs, builder.getAffineMapArrayAttr(indexingMaps), builder.getStrArrayAttr(iteratorTypes), @@ -545,7 +544,7 @@ ValueRange outputs, ArrayRef indexingMaps, ArrayRef iteratorTypes, StringRef doc, StringRef libraryCall, function_ref bodyBuild, - ArrayRef attributes) { + AttributeRange attributes) { build(builder, result, TypeRange{}, inputs, outputs, indexingMaps, iteratorTypes, doc, libraryCall, bodyBuild, attributes); } @@ -555,7 +554,7 @@ ValueRange outputs, ArrayRef indexingMaps, ArrayRef iteratorTypes, function_ref bodyBuild, - ArrayRef attributes) { + AttributeRange attributes) { 
build(builder, result, inputs, outputs, indexingMaps, iteratorTypes, /*doc=*/"", /*libraryCall=*/"", bodyBuild, attributes); @@ -566,7 +565,7 @@ ValueRange inputs, ValueRange outputs, ArrayRef indexingMaps, ArrayRef iteratorTypes, function_ref bodyBuild, - ArrayRef attributes) { + AttributeRange attributes) { build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps, iteratorTypes, /*doc=*/"", @@ -837,7 +836,7 @@ //===----------------------------------------------------------------------===// void InitTensorOp::build(OpBuilder &b, OperationState &result, ArrayRef sizes, Type elementType, - ArrayRef attrs) { + AttributeRange attrs) { unsigned rank = sizes.size(); SmallVector dynamicSizes; SmallVector staticSizes; @@ -1102,8 +1101,7 @@ void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source, ArrayRef staticLow, ArrayRef staticHigh, ValueRange low, - ValueRange high, bool nofold, - ArrayRef attrs) { + ValueRange high, bool nofold, AttributeRange attrs) { auto sourceType = source.getType().cast(); auto resultType = inferResultType(sourceType, staticLow, staticHigh); build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow), @@ -1113,7 +1111,7 @@ void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source, ValueRange low, ValueRange high, bool nofold, - ArrayRef attrs) { + AttributeRange attrs) { auto sourceType = source.getType().cast(); unsigned rank = sourceType.getRank(); SmallVector staticVector(rank, ShapedType::kDynamicSize); @@ -1124,7 +1122,7 @@ void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType, Value source, ArrayRef low, ArrayRef high, bool nofold, - ArrayRef attrs) { + AttributeRange attrs) { assert(resultType.isa()); auto sourceType = source.getType().cast(); unsigned rank = sourceType.getRank(); @@ -1776,8 +1774,7 @@ void mlir::linalg::TensorCollapseShapeOp::build( OpBuilder &b, OperationState &result, Value src, - ArrayRef reassociation, - ArrayRef attrs) { + 
ArrayRef reassociation, AttributeRange attrs) { auto resultType = computeTensorReshapeCollapsedType( src.getType().cast(), getSymbolLessAffineMaps( @@ -1789,8 +1786,7 @@ void mlir::linalg::TensorExpandShapeOp::build( OpBuilder &b, OperationState &result, Value src, - ArrayRef reassociation, - ArrayRef attrs) { + ArrayRef reassociation, AttributeRange attrs) { auto resultType = computeTensorReshapeCollapsedType( src.getType().cast(), getSymbolLessAffineMaps( diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp @@ -173,7 +173,7 @@ // We abuse the GenericOpAdaptor here. // TODO: Manually create an Adaptor that captures inputs and outputs for all // linalg::LinalgOp interface ops. - linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary()); + linalg::GenericOpAdaptor adaptor(getContext(), operands, op->getAttrs()); Location loc = op.getLoc(); SmallVector newOutputBuffers; diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1338,7 +1338,7 @@ MemRefType resultType, Value source, OpFoldResult offset, ArrayRef sizes, ArrayRef strides, - ArrayRef attrs) { + AttributeRange attrs) { SmallVector staticOffsets, staticSizes, staticStrides; SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets, @@ -1356,8 +1356,7 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, MemRefType resultType, Value source, int64_t offset, ArrayRef sizes, - ArrayRef strides, - ArrayRef attrs) { + ArrayRef strides, AttributeRange attrs) { SmallVector sizeValues = llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); @@ -1373,7 +1372,7 @@ void 
ReinterpretCastOp::build(OpBuilder &b, OperationState &result, MemRefType resultType, Value source, Value offset, ValueRange sizes, ValueRange strides, - ArrayRef attrs) { + AttributeRange attrs) { SmallVector sizeValues = llvm::to_vector<4>( llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); SmallVector strideValues = llvm::to_vector<4>( @@ -1552,7 +1551,7 @@ void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src, ArrayRef reassociation, - ArrayRef attrs) { + AttributeRange attrs) { auto memRefType = src.getType().cast(); auto resultType = computeReshapeCollapsedType( memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( @@ -1564,7 +1563,7 @@ void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, ArrayRef reassociation, - ArrayRef attrs) { + AttributeRange attrs) { auto memRefType = src.getType().cast(); auto resultType = computeReshapeCollapsedType( memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( @@ -1858,8 +1857,7 @@ MemRefType resultType, Value source, ArrayRef offsets, ArrayRef sizes, - ArrayRef strides, - ArrayRef attrs) { + ArrayRef strides, AttributeRange attrs) { SmallVector staticOffsets, staticSizes, staticStrides; SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, @@ -1886,16 +1884,14 @@ void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, ArrayRef offsets, ArrayRef sizes, - ArrayRef strides, - ArrayRef attrs) { + ArrayRef strides, AttributeRange attrs) { build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); } // Build a SubViewOp with static entries and inferred result type. 
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, ArrayRef offsets, ArrayRef sizes, - ArrayRef strides, - ArrayRef attrs) { + ArrayRef strides, AttributeRange attrs) { SmallVector offsetValues = llvm::to_vector<4>( llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); @@ -1916,8 +1912,7 @@ void SubViewOp::build(OpBuilder &b, OperationState &result, MemRefType resultType, Value source, ArrayRef offsets, ArrayRef sizes, - ArrayRef strides, - ArrayRef attrs) { + ArrayRef strides, AttributeRange attrs) { SmallVector offsetValues = llvm::to_vector<4>( llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); @@ -1939,7 +1934,7 @@ void SubViewOp::build(OpBuilder &b, OperationState &result, MemRefType resultType, Value source, ValueRange offsets, ValueRange sizes, ValueRange strides, - ArrayRef attrs) { + AttributeRange attrs) { SmallVector offsetValues = llvm::to_vector<4>( llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); SmallVector sizeValues = llvm::to_vector<4>( @@ -1952,7 +1947,7 @@ // Build a SubViewOp with dynamic entries and inferred result type. 
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, ValueRange offsets, ValueRange sizes, ValueRange strides, - ArrayRef attrs) { + AttributeRange attrs) { build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); } @@ -2285,8 +2280,7 @@ } void TransposeOp::build(OpBuilder &b, OperationState &result, Value in, - AffineMapAttr permutation, - ArrayRef attrs) { + AffineMapAttr permutation, AttributeRange attrs) { auto permutationMap = permutation.getValue(); assert(permutationMap); diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -58,7 +58,7 @@ //===----------------------------------------------------------------------===// void ParallelOp::build(OpBuilder &builder, OperationState &state, - ArrayRef attributes) { + AttributeRange attributes) { ParallelOp::build( builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr, /*default_val=*/nullptr, /*private_vars=*/ValueRange(), @@ -1061,7 +1061,7 @@ void WsLoopOp::build(OpBuilder &builder, OperationState &state, ValueRange lowerBound, ValueRange upperBound, - ValueRange step, ArrayRef attributes) { + ValueRange step, AttributeRange attributes) { build(builder, state, TypeRange(), lowerBound, upperBound, step, /*private_vars=*/ValueRange(), /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(), @@ -1074,7 +1074,7 @@ } void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes, - ValueRange operands, ArrayRef attributes) { + ValueRange operands, AttributeRange attributes) { state.addOperands(operands); state.addAttributes(attributes); (void)state.addRegion(); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ 
b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -354,7 +354,7 @@ } bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const { - return lhs->getAttrDictionary() == rhs->getAttrDictionary(); + return lhs->getAttrs() == rhs->getAttrs(); } // Returns a source value for the given block. diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1984,7 +1984,7 @@ void spirv::FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, FunctionType type, spirv::FunctionControl control, - ArrayRef attrs) { + AttributeRange attrs) { state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); state.addAttribute(getTypeAttrName(), TypeAttr::get(type)); diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -423,7 +423,7 @@ LogicalResult mlir::shape::AddOp::inferReturnTypes( MLIRContext *context, Optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, + AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { if (operands[0].getType().isa() || operands[1].getType().isa()) @@ -821,7 +821,7 @@ LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes( MLIRContext *context, Optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, + AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { Builder b(context); auto shape = attributes.getAs("shape"); @@ -999,7 +999,7 @@ LogicalResult mlir::shape::DivOp::inferReturnTypes( MLIRContext *context, Optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, + AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { if (operands[0].getType().isa() || 
operands[1].getType().isa()) @@ -1157,7 +1157,7 @@ LogicalResult mlir::shape::GetExtentOp::inferReturnTypes( MLIRContext *context, Optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, + AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { inferredReturnTypes.assign({IndexType::get(context)}); return success(); @@ -1193,7 +1193,7 @@ LogicalResult mlir::shape::MeetOp::inferReturnTypes( MLIRContext *context, Optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, + AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { inferredReturnTypes.assign({operands[0].getType()}); return success(); @@ -1281,7 +1281,7 @@ LogicalResult mlir::shape::RankOp::inferReturnTypes( MLIRContext *context, Optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, + AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { if (operands[0].getType().isa()) inferredReturnTypes.assign({SizeType::get(context)}); @@ -1315,7 +1315,7 @@ LogicalResult mlir::shape::NumElementsOp::inferReturnTypes( MLIRContext *context, Optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, + AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { if (operands[0].getType().isa()) inferredReturnTypes.assign({SizeType::get(context)}); @@ -1343,7 +1343,7 @@ LogicalResult mlir::shape::MaxOp::inferReturnTypes( MLIRContext *context, Optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, + AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { if (operands[0].getType() == operands[1].getType()) inferredReturnTypes.assign({operands[0].getType()}); @@ -1375,7 +1375,7 @@ LogicalResult mlir::shape::MinOp::inferReturnTypes( MLIRContext *context, Optional location, ValueRange operands, - 
DictionaryAttr attributes, RegionRange regions, + AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { if (operands[0].getType() == operands[1].getType()) inferredReturnTypes.assign({operands[0].getType()}); @@ -1412,7 +1412,7 @@ LogicalResult mlir::shape::MulOp::inferReturnTypes( MLIRContext *context, Optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, + AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { if (operands[0].getType().isa() || operands[1].getType().isa()) @@ -1496,7 +1496,7 @@ LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes( MLIRContext *context, Optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, + AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { if (operands[0].getType().isa()) inferredReturnTypes.assign({ShapeType::get(context)}); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -742,7 +742,7 @@ ArrayRef offsets, ArrayRef sizes, ArrayRef strides, - ArrayRef attrs) { + AttributeRange attrs) { SmallVector staticOffsets, staticSizes, staticStrides; SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, @@ -771,7 +771,7 @@ ArrayRef offsets, ArrayRef sizes, ArrayRef strides, - ArrayRef attrs) { + AttributeRange attrs) { build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs); } @@ -780,7 +780,7 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result, RankedTensorType resultType, Value source, ValueRange offsets, ValueRange sizes, - ValueRange strides, ArrayRef attrs) { + ValueRange strides, AttributeRange attrs) { SmallVector offsetValues = llvm::to_vector<4>( llvm::map_range(offsets, [](Value v) -> OpFoldResult { return 
v; })); SmallVector sizeValues = llvm::to_vector<4>( @@ -793,7 +793,7 @@ /// Build an ExtractSliceOp with dynamic entries and inferred result type. void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source, ValueRange offsets, ValueRange sizes, - ValueRange strides, ArrayRef attrs) { + ValueRange strides, AttributeRange attrs) { build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs); } @@ -1079,7 +1079,7 @@ Value dest, ArrayRef offsets, ArrayRef sizes, ArrayRef strides, - ArrayRef attrs) { + AttributeRange attrs) { SmallVector staticOffsets, staticSizes, staticStrides; SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, @@ -1097,7 +1097,7 @@ // Build a InsertSliceOp with dynamic entries. void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source, Value dest, ValueRange offsets, ValueRange sizes, - ValueRange strides, ArrayRef attrs) { + ValueRange strides, AttributeRange attrs) { SmallVector offsetValues = llvm::to_vector<4>( llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); SmallVector sizeValues = llvm::to_vector<4>( diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -589,7 +589,7 @@ LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape = operands.getShape(0); IntegerAttr axis = attributes.get("axis").cast(); @@ -614,7 +614,7 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange 
operands, AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { // Infer all dimension sizes by reducing based on inputs. int32_t axis = @@ -671,7 +671,7 @@ LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape = operands.getShape(0); ShapeAdaptor weightShape = operands.getShape(1); @@ -701,7 +701,7 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor lhsShape = operands.getShape(0); ShapeAdaptor rhsShape = operands.getShape(1); @@ -728,7 +728,7 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape = operands.getShape(0); ShapeAdaptor paddingShape = operands.getShape(1); @@ -784,9 +784,9 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { - ArrayAttr sizes = SliceOpAdaptor(operands, attributes).size(); + ArrayAttr sizes = SliceOpAdaptor(context, operands, attributes).size(); SmallVector outputShape; outputShape.reserve(sizes.size()); for (auto val : sizes) { @@ -799,7 +799,7 @@ LogicalResult 
tosa::TableOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape = operands.getShape(0); @@ -815,9 +815,9 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { - TileOpAdaptor adaptor(operands, attributes); + TileOpAdaptor adaptor(context, operands, attributes); ArrayAttr multiples = adaptor.multiples(); ShapeAdaptor inputShape = operands.getShape(0); SmallVector outputShape; @@ -849,9 +849,9 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { - ReshapeOpAdaptor adaptor(operands, attributes); + ReshapeOpAdaptor adaptor(context, operands, attributes); ShapeAdaptor inputShape = operands.getShape(0); ArrayAttr newShape = adaptor.new_shape(); @@ -888,7 +888,7 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape = operands.getShape(0); ShapeAdaptor permsShape = operands.getShape(1); @@ -955,7 +955,7 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueShapeRange operands, DictionaryAttr attributes, RegionRange 
regions, + ValueShapeRange operands, AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape; outputShape.resize(3, ShapedType::kDynamicSize); @@ -980,9 +980,9 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { - ResizeOpAdaptor adaptor(operands, attributes); + ResizeOpAdaptor adaptor(context, operands, attributes); llvm::SmallVector outputShape; outputShape.resize(4, ShapedType::kDynamicSize); @@ -1051,7 +1051,7 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape; outputShape.resize(3, ShapedType::kDynamicSize); @@ -1100,7 +1100,7 @@ #define REDUCE_SHAPE_INFER(OP) \ LogicalResult OP::inferReturnTypeComponents( \ MLIRContext *context, ::llvm::Optional location, \ - ValueShapeRange operands, DictionaryAttr attributes, \ + ValueShapeRange operands, AttributeRange attributes, \ RegionRange regions, \ SmallVectorImpl &inferredReturnShapes) { \ return ReduceInferReturnTypes(operands.getShape(0), \ @@ -1167,7 +1167,7 @@ #define NARY_SHAPE_INFER(OP) \ LogicalResult OP::inferReturnTypeComponents( \ MLIRContext *context, ::llvm::Optional location, \ - ValueShapeRange operands, DictionaryAttr attributes, \ + ValueShapeRange operands, AttributeRange attributes, \ RegionRange regions, \ SmallVectorImpl &inferredReturnShapes) { \ return NAryInferReturnTypes(operands, inferredReturnShapes); \ @@ -1215,7 +1215,7 @@ #undef PRED_SHAPE_INFER static LogicalResult poolingInferReturnTypes( - const 
ValueShapeRange &operands, DictionaryAttr attributes, + const ValueShapeRange &operands, AttributeRange attributes, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape = operands.getShape(0); llvm::SmallVector outputShape; @@ -1258,10 +1258,10 @@ LogicalResult Conv2DOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape(4, ShapedType::kDynamicSize); - Conv2DOp::Adaptor adaptor(operands.getValues(), attributes); + Conv2DOp::Adaptor adaptor(context, operands.getValues(), attributes); int32_t inputWidth = ShapedType::kDynamicSize; int32_t inputHeight = ShapedType::kDynamicSize; @@ -1323,10 +1323,10 @@ LogicalResult Conv3DOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape(5, ShapedType::kDynamicSize); - Conv2DOp::Adaptor adaptor(operands.getValues(), attributes); + Conv2DOp::Adaptor adaptor(context, operands.getValues(), attributes); int32_t inputWidth = ShapedType::kDynamicSize; int32_t inputHeight = ShapedType::kDynamicSize; @@ -1399,24 +1399,24 @@ LogicalResult AvgPool2dOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { return poolingInferReturnTypes(operands, attributes, inferredReturnShapes); } LogicalResult MaxPool2dOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueShapeRange operands, DictionaryAttr 
attributes, RegionRange regions, + ValueShapeRange operands, AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { return poolingInferReturnTypes(operands, attributes, inferredReturnShapes); } LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape(4, ShapedType::kDynamicSize); - DepthwiseConv2DOp::Adaptor adaptor(operands.getValues(), attributes); + DepthwiseConv2DOp::Adaptor adaptor(context, operands.getValues(), attributes); int32_t inputWidth = ShapedType::kDynamicSize; int32_t inputHeight = ShapedType::kDynamicSize; @@ -1491,9 +1491,9 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { - TransposeConv2DOp::Adaptor adaptor(operands.getValues(), attributes); + TransposeConv2DOp::Adaptor adaptor(context, operands.getValues(), attributes); llvm::SmallVector outputShape; getI64Values(adaptor.out_shape(), outputShape); @@ -1559,7 +1559,7 @@ LogicalResult IfOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector yieldOps; for (Region *region : regions) { @@ -1603,7 +1603,7 @@ LogicalResult WhileOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, 
AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector yieldOps; for (auto &block : *regions[1]) diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp @@ -213,8 +213,8 @@ ValueShapeRange range(op.getOperands(), operandShape); if (shapeInterface .inferReturnTypeComponents(op.getContext(), op.getLoc(), range, - op.getAttrDictionary(), - op.getRegions(), returnedShapes) + op.getAttrs(), op.getRegions(), + returnedShapes) .succeeded()) { for (auto it : llvm::zip(op.getResults(), returnedShapes)) { Value result = std::get<0>(it); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -460,7 +460,7 @@ print(&b); } - void printRegionArgument(BlockArgument arg, ArrayRef argAttrs, + void printRegionArgument(BlockArgument arg, AttributeRange argAttrs, bool omitType) override { printType(arg.getType()); // Visit the argument location. @@ -480,7 +480,7 @@ /// Print the given set of attributes with names not included within /// 'elidedAttrs'. 
- void printOptionalAttrDict(ArrayRef attrs, + void printOptionalAttrDict(AttributeRange attrs, ArrayRef elidedAttrs = {}) override { if (attrs.empty()) return; @@ -496,8 +496,7 @@ printAttribute(attr.second); } void printOptionalAttrDictWithKeyword( - ArrayRef attrs, - ArrayRef elidedAttrs = {}) override { + AttributeRange attrs, ArrayRef elidedAttrs = {}) override { printOptionalAttrDict(attrs, elidedAttrs); } @@ -1269,7 +1268,7 @@ void printIntegerSet(IntegerSet set); protected: - void printOptionalAttrDict(ArrayRef attrs, + void printOptionalAttrDict(AttributeRange attrs, ArrayRef elidedAttrs = {}, bool withKeyword = false); void printNamedAttribute(NamedAttribute attr); @@ -2007,7 +2006,7 @@ .Default([&](Type type) { return printDialectType(type); }); } -void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef attrs, +void AsmPrinter::Impl::printOptionalAttrDict(AttributeRange attrs, ArrayRef elidedAttrs, bool withKeyword) { // If there are no attributes, then there is nothing to be done. @@ -2385,8 +2384,7 @@ /// where location printing is controlled by the standard internal option. /// You may pass omitType=true to not print a type, and pass an empty /// attribute list if you don't care for attributes. - void printRegionArgument(BlockArgument arg, - ArrayRef argAttrs = {}, + void printRegionArgument(BlockArgument arg, AttributeRange argAttrs = {}, bool omitType = false) override; /// Print the ID for the given value. @@ -2396,13 +2394,12 @@ } /// Print an optional attribute dictionary with a given set of elided values. 
- void printOptionalAttrDict(ArrayRef attrs, + void printOptionalAttrDict(AttributeRange attrs, ArrayRef elidedAttrs = {}) override { Impl::printOptionalAttrDict(attrs, elidedAttrs); } void printOptionalAttrDictWithKeyword( - ArrayRef attrs, - ArrayRef elidedAttrs = {}) override { + AttributeRange attrs, ArrayRef elidedAttrs = {}) override { Impl::printOptionalAttrDict(attrs, elidedAttrs, /*withKeyword=*/true); } @@ -2471,7 +2468,7 @@ /// You may pass omitType=true to not print a type, and pass an empty /// attribute list if you don't care for attributes. void OperationPrinter::printRegionArgument(BlockArgument arg, - ArrayRef argAttrs, + AttributeRange argAttrs, bool omitType) { printOperand(arg); if (!omitType) { diff --git a/mlir/lib/IR/AttributeRange.cpp b/mlir/lib/IR/AttributeRange.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/IR/AttributeRange.cpp @@ -0,0 +1,42 @@ +#include "mlir/IR/AttributeRange.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/OperationSupport.h" + +using namespace mlir; + +AttributeRange::AttributeRange(const NamedAttrList &attrs) + : Base(attrs), sorted(attrs.isSorted()) {} + +AttributeRange::AttributeRange(DictionaryAttr attrs) + : AttributeRange(attrs ? 
attrs.getValue() : AttributeRange({}, true)) { + assert(!attrs || + isSorted() && "expected dictionary attribute sorted flag to be true"); +} + +AttributeRange::AttributeRange(const detail::OpAttributeList &attrs) + : Base(attrs), sorted(true) {} + +const NamedAttribute * +impl::detail::lookupAttributeBig(ArrayRef attrs, + Identifier name) { + unsigned length = attrs.size(); + const NamedAttribute *first = attrs.begin(); + StringRef nameStr = name.strref(); + + // TODO: I wrote this binary search twice + while (length > 0) { + unsigned half = length / 2; + const NamedAttribute *mid = first + half; + auto c = mid->first.strref().compare(nameStr); + if (c < 0) { + first = mid; + ++first; + length = length - half - 1; + } else if (c > 0) { + length = half; + } else { + return mid; + } + } + return nullptr; +} diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -92,7 +92,7 @@ return BoolAttr::get(context, value); } -DictionaryAttr Builder::getDictionaryAttr(ArrayRef value) { +DictionaryAttr Builder::getDictionaryAttr(AttributeRange value) { return DictionaryAttr::get(context, value); } diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -62,7 +62,7 @@ /// destination, else value is the source and storage destination. Returns /// whether source was sorted. template -static bool dictionaryAttrSort(ArrayRef value, +static bool dictionaryAttrSort(AttributeRange value, SmallVectorImpl &storage) { // Specialize for the common case. switch (value.size()) { @@ -103,8 +103,7 @@ /// Returns an entry with a duplicate name from the given sorted array of named /// attributes. Returns llvm::None if all elements have unique names. 
-static Optional -findDuplicateElement(ArrayRef value) { +static Optional findDuplicateElement(AttributeRange value) { const Optional none{llvm::None}; if (value.size() < 2) return none; @@ -118,7 +117,7 @@ return it != value.end() ? *it : none; } -bool DictionaryAttr::sort(ArrayRef value, +bool DictionaryAttr::sort(AttributeRange value, SmallVectorImpl &storage) { bool isSorted = dictionaryAttrSort(value, storage); assert(!findDuplicateElement(storage) && @@ -141,8 +140,7 @@ return findDuplicateElement(array); } -DictionaryAttr DictionaryAttr::get(MLIRContext *context, - ArrayRef value) { +DictionaryAttr DictionaryAttr::get(MLIRContext *context, AttributeRange value) { if (value.empty()) return DictionaryAttr::getEmpty(context); assert(llvm::all_of(value, @@ -160,7 +158,7 @@ /// Construct a dictionary with an array of values that is known to already be /// sorted by name and uniqued. DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context, - ArrayRef value) { + AttributeRange value) { if (value.empty()) return DictionaryAttr::getEmpty(context); // Ensure that the attribute elements are unique and sorted. @@ -186,8 +184,8 @@ /// Return the specified named attribute if present, None otherwise. Optional DictionaryAttr::getNamed(StringRef name) const { - ArrayRef values = getValue(); - const auto *it = llvm::lower_bound(values, name); + AttributeRange values = getValue(); + auto it = llvm::lower_bound(values, name); return it != values.end() && it->first == name ? 
*it : Optional(); } @@ -199,15 +197,15 @@ } DictionaryAttr::iterator DictionaryAttr::begin() const { - return getValue().begin(); + return AttributeRange(getValue()).begin(); } DictionaryAttr::iterator DictionaryAttr::end() const { - return getValue().end(); + return AttributeRange(getValue()).end(); } size_t DictionaryAttr::size() const { return getValue().size(); } DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) { - return Base::get(context, ArrayRef()); + return Base::get(context, AttributeRange()); } void DictionaryAttr::walkImmediateSubElements( diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -77,7 +77,7 @@ //===----------------------------------------------------------------------===// FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, - ArrayRef attrs) { + AttributeRange attrs) { OpBuilder builder(location->getContext()); OperationState state(location, getOperationName()); FuncOp::build(builder, state, name, type, attrs); @@ -86,18 +86,17 @@ FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, Operation::dialect_attr_range attrs) { SmallVector attrRef(attrs); - return create(location, name, type, llvm::makeArrayRef(attrRef)); + return create(location, name, type, AttributeRange(attrRef)); } FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, - ArrayRef attrs, - ArrayRef argAttrs) { + AttributeRange attrs, ArrayRef argAttrs) { FuncOp func = create(location, name, type, attrs); func.setAllArgAttrs(argAttrs); return func; } void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, - FunctionType type, ArrayRef attrs, + FunctionType type, AttributeRange attrs, ArrayRef argAttrs) { state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt --- 
a/mlir/lib/IR/CMakeLists.txt +++ b/mlir/lib/IR/CMakeLists.txt @@ -2,6 +2,7 @@ AffineExpr.cpp AffineMap.cpp AsmPrinter.cpp + AttributeRange.cpp Attributes.cpp Block.cpp Builders.cpp diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp --- a/mlir/lib/IR/FunctionImplementation.cpp +++ b/mlir/lib/IR/FunctionImplementation.cpp @@ -305,7 +305,7 @@ p << ", "; if (!isExternal) { - ArrayRef attrs; + AttributeRange attrs; if (argAttrs) attrs = argAttrs[i].cast().getValue(); p.printRegionArgument(body.getArgument(i), attrs); diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -66,27 +66,16 @@ // Operation //===----------------------------------------------------------------------===// -/// Create a new Operation with the specific fields. -Operation *Operation::create(Location location, OperationName name, - TypeRange resultTypes, ValueRange operands, - ArrayRef attributes, - BlockRange successors, unsigned numRegions) { - return create(location, name, resultTypes, operands, - DictionaryAttr::get(location.getContext(), attributes), - successors, numRegions); -} - /// Create a new Operation from operation state. Operation *Operation::create(const OperationState &state) { return create(state.location, state.name, state.types, state.operands, - state.attributes.getDictionary(state.getContext()), - state.successors, state.regions); + state.attributes, state.successors, state.regions); } /// Create a new Operation with the specific fields. 
Operation *Operation::create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, - DictionaryAttr attributes, BlockRange successors, + AttributeRange attributes, BlockRange successors, RegionRange regions) { unsigned numRegions = regions.size(); Operation *op = create(location, name, resultTypes, operands, attributes, @@ -97,11 +86,10 @@ return op; } -/// Overload of create that takes an existing DictionaryAttr to avoid -/// unnecessarily uniquing a list of attributes. +/// Create a new Operation with the specific fields. Operation *Operation::create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, - DictionaryAttr attributes, BlockRange successors, + AttributeRange attributes, BlockRange successors, unsigned numRegions) { assert(llvm::all_of(resultTypes, [](Type t) { return t; }) && "unexpected null result type"); @@ -121,13 +109,18 @@ needsOperandStorage = !abstractOp->hasTrait(); } + // Determine whether the attribute list is too "large" to fit inline. + unsigned numInlineAttrs = 0; + // attributes.size() <= kMaxInlineAttrs ? attributes.size() : 0; + // Compute the byte size for the operation and the operand storage. This takes // into account the size of the operation, its trailing objects, and its // prefixed objects. - size_t byteSize = - totalSizeToAlloc( - numSuccessors, numRegions, needsOperandStorage ? 1 : 0) + - detail::OperandStorage::additionalAllocSize(numOperands); + size_t byteSize = totalSizeToAlloc( + numSuccessors, numRegions, numInlineAttrs, + needsOperandStorage ? 1 : 0) + + detail::OperandStorage::additionalAllocSize(numOperands); size_t prefixByteSize = llvm::alignTo( Operation::prefixAllocSize(numTrailingResults, numInlineResults), alignof(Operation)); @@ -135,9 +128,9 @@ void *rawMem = mallocMem + prefixByteSize; // Create the new Operation. 
- Operation *op = - ::new (rawMem) Operation(location, name, numResults, numSuccessors, - numRegions, attributes, needsOperandStorage); + Operation *op = ::new (rawMem) + Operation(location, name, numResults, numSuccessors, numRegions, + numInlineAttrs, attributes, needsOperandStorage); assert((numSuccessors == 0 || op->mightHaveTrait()) && "unexpected successors in a non-terminator operation"); @@ -169,11 +162,12 @@ Operation::Operation(Location location, OperationName name, unsigned numResults, unsigned numSuccessors, unsigned numRegions, - DictionaryAttr attributes, bool hasOperandStorage) + unsigned numInlineAttributes, AttributeRange attributes, + bool hasOperandStorage) : location(location), numResults(numResults), numSuccs(numSuccessors), - numRegions(numRegions), hasOperandStorage(hasOperandStorage), name(name), - attrs(attributes) { - assert(attributes && "unexpected null attribute dictionary"); + numRegions(numRegions), numInlineAttrs(numInlineAttributes), + hasOperandStorage(hasOperandStorage), name(name), + attrs(getInlineAttrStorage(), attributes) { #ifndef NDEBUG if (!getDialect() && !getContext()->allowsUnregisteredDialects()) llvm::report_fatal_error( @@ -1049,11 +1043,11 @@ continue; if (region.getNumArguments() != 0) { - if (op->getNumRegions() > 1) + if (op->getNumRegions() > 1) { return op->emitOpError("region #") << region.getRegionNumber() << " should have no arguments"; - else - return op->emitOpError("region should have no arguments"); + } + return op->emitOpError("region should have no arguments"); } } return success(); diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -24,13 +24,10 @@ // NamedAttrList //===----------------------------------------------------------------------===// -NamedAttrList::NamedAttrList(ArrayRef attributes) { - assign(attributes.begin(), attributes.end()); -} +NamedAttrList::NamedAttrList(AttributeRange 
attributes) { assign(attributes); } NamedAttrList::NamedAttrList(DictionaryAttr attributes) - : NamedAttrList(attributes ? attributes.getValue() - : ArrayRef()) { + : NamedAttrList(attributes ? attributes.getValue() : AttributeRange()) { dictionarySorted.setPointerAndInt(attributes, true); } @@ -38,7 +35,7 @@ assign(in_start, in_end); } -ArrayRef NamedAttrList::getAttrs() const { return attrs; } +AttributeRange NamedAttrList::getAttrs() const { return attrs; } Optional NamedAttrList::findDuplicate() const { Optional duplicate = @@ -175,6 +172,95 @@ NamedAttrList::operator ArrayRef() const { return attrs; } +//===----------------------------------------------------------------------===// +// OpAttributeList +//===----------------------------------------------------------------------===// + +detail::OpAttributeList::OpAttributeList( + MutableArrayRef initialStorage, AttributeRange attributes) + : attrs(/*initialStorage, */ attributes.begin(), attributes.end()) { + assert(llvm::none_of(attributes, + [](NamedAttribute attr) { return !attr.second; }) && + "cannot have null attributes"); + // TODO: llvm::array_pod_sort calls `qsort` which is somehow making heap + // allocations...? + if (!attributes.isSorted()) + std::sort(attrs.begin(), attrs.end()); +} + +void detail::OpAttributeList::set(StringRef name, Attribute value) { + assert(value && "attribute cannot be null"); + auto it = lookupAttribute(name); + if (it.second) { + it.first->second = value; + } else { + attrs.insert(it.first, {Identifier::get(name, value.getContext()), value}); + } +} + +void detail::OpAttributeList::set(Identifier name, Attribute value, + unsigned left, unsigned right) { + assert(value && "attribute cannot be null"); + auto it = lookupAttribute(name, left, right); + if (it.second) { + it.first->second = value; + } else { + // Need to perform a string lookup to do a sorted insert. 
+ auto insertIt = lookupAttribute(name.strref()); + assert(!insertIt.second); + attrs.insert(insertIt.first, {name, value}); + } +} + +void detail::OpAttributeList::addOrSet(Identifier name, Attribute value) { + assert(value && "attribute cannot be null"); + // Go straight to string lookup. + auto it = lookupAttribute(name.strref()); + if (LLVM_UNLIKELY(it.second)) { + it.first->second = value; + } else { + attrs.insert(it.first, {name, value}); + } +} + +void detail::OpAttributeList::setOrAdd(StringRef name, Attribute value) { + assert(value && "attribute cannot be null"); + // Go straight to identifier lookup. + auto nameId = Identifier::get(name, value.getContext()); + auto it = lookupAttribute(nameId); + if (LLVM_LIKELY(it.second)) { + it.first->second = value; + } else { + // Need to perform a string lookup to do a sorted insert. + auto insertIt = lookupAttribute(name); + assert(!insertIt.second); + attrs.insert(insertIt.first, {nameId, value}); + } +} + +void detail::OpAttributeList::definitelySet(Identifier name, Attribute value, + unsigned left, unsigned right) { + auto it = lookupAttribute(name, left, right); + assert(it.second && "attribute to set not found"); + it.first->second = value; +} + +void detail::OpAttributeList::assign(AttributeRange attributes) { + assert(llvm::none_of(attributes, + [](NamedAttribute attr) { return !attr.second; }) && + "cannot have null attributes"); + // Re-using the existing storage may prevent a re-allocation. 
+ attrs.assign(attributes.begin(), attributes.end()); + if (!attributes.isSorted()) + std::sort(attrs.begin(), attrs.end()); +} + +DictionaryAttr +detail::OpAttributeList::convertToDictionary(MLIRContext *context) const { + return DictionaryAttr::getWithSorted(context, + ArrayRef(attrs)); +} + //===----------------------------------------------------------------------===// // OperationState //===----------------------------------------------------------------------===// @@ -187,8 +273,7 @@ OperationState::OperationState(Location location, StringRef name, ValueRange operands, TypeRange types, - ArrayRef attributes, - BlockRange successors, + AttributeRange attributes, BlockRange successors, MutableArrayRef> regions) : location(location), name(name, location->getContext()), operands(operands.begin(), operands.end()), @@ -652,8 +737,8 @@ // - Operation Name // - Attributes // - Result Types - llvm::hash_code hash = llvm::hash_combine( - op->getName(), op->getAttrDictionary(), op->getResultTypes()); + llvm::hash_code hash = + llvm::hash_combine(op->getName(), op->getAttrs(), op->getResultTypes()); // - Operands for (Value operand : op->getOperands()) @@ -723,8 +808,7 @@ return true; // Compare the operation properties. - if (lhs->getName() != rhs->getName() || - lhs->getAttrDictionary() != rhs->getAttrDictionary() || + if (lhs->getName() != rhs->getName() || lhs->getAttrs() != rhs->getAttrs() || lhs->getNumRegions() != rhs->getNumRegions() || lhs->getNumSuccessors() != rhs->getNumSuccessors() || lhs->getNumOperands() != rhs->getNumOperands() || diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -462,6 +462,10 @@ // Symbol Use Lists //===----------------------------------------------------------------------===// +/// Abstraction of an attribute container. An attribute container can be the +/// operation itself, a dictionary attribute, or an array attribute. 
+using AttrContainer = llvm::PointerUnion; + /// Walk all of the symbol references within the given operation, invoking the /// provided callback for each found use. The callbacks takes as arguments: the /// use of the symbol, and the nested access chain to the attribute within the @@ -479,13 +483,14 @@ Operation *op, function_ref)> callback) { // Check to see if the operation has any attributes. - DictionaryAttr attrDict = op->getAttrDictionary(); - if (attrDict.empty()) + AttributeRange attrs = op->getAttrs(); + if (attrs.empty()) return WalkResult::advance(); // A worklist of a container attribute and the current index into the held - // attribute list. - SmallVector attrWorklist(1, attrDict); + // attribute list. The first container of attributes is always the operation + // itself. + SmallVector attrWorklist(1, op); SmallVector curAccessChain(1, /*Value=*/-1); // Process the symbol references within the given nested attribute range. @@ -517,15 +522,23 @@ WalkResult result = WalkResult::advance(); do { - Attribute attr = attrWorklist.back(); + AttrContainer container = attrWorklist.back(); int &index = curAccessChain.back(); ++index; // Process the given attribute, which is guaranteed to be a container. - if (auto dict = attr.dyn_cast()) - result = processAttrs(index, make_second_range(dict.getValue())); - else - result = processAttrs(index, attr.cast().getValue()); + if (auto attr = container.dyn_cast()) { + if (auto dictAttr = attr.dyn_cast()) { + result = + processAttrs(index, llvm::make_second_range(dictAttr.getValue())); + } else { + result = processAttrs(index, attr.cast().getValue()); + } + } else { + result = processAttrs( + index, + llvm::make_second_range(container.get()->getAttrs())); + } } while (!attrWorklist.empty() && !result.wasInterrupted()); return result; } @@ -811,8 +824,8 @@ /// Rebuild the given attribute container after replacing all references to a /// symbol with the updated attribute in 'accesses'. 
-static Attribute rebuildAttrAfterRAUW( - Attribute container, +static AttrContainer rebuildAttrAfterRAUW( + AttrContainer container, ArrayRef, SymbolRefAttr>> accesses, unsigned depth) { // Given a range of Attributes, update the ones referred to by the given @@ -838,21 +851,32 @@ return nextAccess.size() > depth + 1 && nextAccess[depth] == access[depth]; }); - attr = rebuildAttrAfterRAUW(attr, nestedAccesses, depth + 1); + AttrContainer result = + rebuildAttrAfterRAUW(attr, nestedAccesses, depth + 1); + assert(result.is() && + "expected nested container to be an attribute"); + attr = result.get(); // Skip over all of the accesses that refer to the nested container. i += nestedAccesses.size(); } }; - if (auto dictAttr = container.dyn_cast()) { - auto newAttrs = llvm::to_vector<4>(dictAttr.getValue()); - updateAttrs(make_second_range(newAttrs)); - return DictionaryAttr::get(dictAttr.getContext(), newAttrs); + if (auto attr = container.dyn_cast()) { + if (auto dictAttr = attr.dyn_cast()) { + auto newAttrs = llvm::to_vector<4>(dictAttr.getValue()); + updateAttrs(make_second_range(newAttrs)); + return DictionaryAttr::get(dictAttr.getContext(), newAttrs); + } + auto newAttrs = llvm::to_vector<4>(attr.cast().getValue()); + updateAttrs(newAttrs); + return ArrayAttr::get(attr.getContext(), newAttrs); } - auto newAttrs = llvm::to_vector<4>(container.cast().getValue()); - updateAttrs(newAttrs); - return ArrayAttr::get(container.getContext(), newAttrs); + auto *op = container.get(); + auto opAttrs = llvm::to_vector<4>(op->getAttrs()); + updateAttrs(make_second_range(opAttrs)); + op->setAttrs(AttributeRange(opAttrs, /*sorted=*/true)); + return op; } /// Generates a new symbol reference attribute with a new leaf reference. @@ -869,9 +893,6 @@ template static LogicalResult replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) { - // A collection of operations along with their new attribute dictionary. 
- std::vector> updatedAttrDicts; - // The current operation being processed. Operation *curOp = nullptr; @@ -881,10 +902,8 @@ // Generate a new attribute dictionary for the current operation by replacing // references to the old symbol. - auto generateNewAttrDict = [&] { - auto oldDict = curOp->getAttrDictionary(); - auto newDict = rebuildAttrAfterRAUW(oldDict, accessChains, /*depth=*/0); - return newDict.cast(); + auto updateOpAttrs = [&](Operation *op) { + rebuildAttrAfterRAUW(op, accessChains, /*depth=*/0); }; // Generate a new attribute to replace the given attribute. @@ -918,7 +937,7 @@ // for it. This means that we've finished processing the current // operation, so generate a new dictionary for it. if (curOp && symbolUse.getUser() != curOp) { - updatedAttrDicts.push_back({curOp, generateNewAttrDict()}); + updateOpAttrs(curOp); accessChains.clear(); } @@ -932,14 +951,11 @@ // Check to see if we have a dangling op that needs to be processed. if (curOp) { - updatedAttrDicts.push_back({curOp, generateNewAttrDict()}); + updateOpAttrs(curOp); curOp = nullptr; } } - // Update the attribute dictionaries as necessary. 
- for (auto &it : updatedAttrDicts) - it.first->setAttrs(it.second); return success(); } diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp --- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp +++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp @@ -179,11 +179,11 @@ LogicalResult mlir::detail::inferReturnTensorTypes( function_ref location, ValueShapeRange operands, - DictionaryAttr attributes, RegionRange regions, + AttributeRange attributes, RegionRange regions, SmallVectorImpl &retComponents)> componentTypeFn, MLIRContext *context, Optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, + AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { SmallVector retComponents; if (failed(componentTypeFn(context, location, operands, attributes, regions, @@ -204,9 +204,9 @@ LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) { SmallVector inferredReturnTypes; auto retTypeFn = cast(op); - if (failed(retTypeFn.inferReturnTypes( - op->getContext(), op->getLoc(), op->getOperands(), - op->getAttrDictionary(), op->getRegions(), inferredReturnTypes))) + if (failed(retTypeFn.inferReturnTypes(op->getContext(), op->getLoc(), + op->getOperands(), op->getAttrs(), + op->getRegions(), inferredReturnTypes))) return failure(); if (!retTypeFn.isCompatibleReturnTypes(inferredReturnTypes, op->getResultTypes())) diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp --- a/mlir/lib/Pass/IRPrinting.cpp +++ b/mlir/lib/Pass/IRPrinting.cpp @@ -32,7 +32,10 @@ // - Operation pointer addDataToHash(hasher, op); // - Attributes - addDataToHash(hasher, op->getAttrDictionary()); + for (NamedAttribute attr : op->getAttrs()) { + addDataToHash(hasher, attr.first); + addDataToHash(hasher, attr.second); + } // - Blocks in Regions for (Region ®ion : op->getRegions()) { for (Block &block : region) { diff --git a/mlir/lib/Rewrite/ByteCode.cpp 
b/mlir/lib/Rewrite/ByteCode.cpp --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -1330,10 +1330,9 @@ // TODO: Handle failure. state.types.clear(); - if (failed(concept->inferReturnTypes( - state.getContext(), state.location, state.operands, - state.attributes.getDictionary(state.getContext()), state.regions, - state.types))) + if (failed(concept->inferReturnTypes(state.getContext(), state.location, + state.operands, state.attributes, + state.regions, state.types))) return; break; } diff --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp --- a/mlir/lib/Transforms/SCCP.cpp +++ b/mlir/lib/Transforms/SCCP.cpp @@ -86,7 +86,7 @@ // folds in-place. The constant passed in may not correspond to the real // runtime value, so in-place updates are not allowed. SmallVector originalOperands(op->getOperands()); - DictionaryAttr originalAttrs = op->getAttrDictionary(); + SmallVector originalAttrs = op->getAttrs().toVector(); // Simulate the result of folding this operation to a constant. If folding // fails or was an in-place fold, mark the results as overdefined. 
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -241,7 +241,7 @@ public: OperationTransactionState() = default; OperationTransactionState(Operation *op) - : op(op), loc(op->getLoc()), attrs(op->getAttrDictionary()), + : op(op), loc(op->getLoc()), attrs(op->getAttrs().toVector()), operands(op->operand_begin(), op->operand_end()), successors(op->successor_begin(), op->successor_end()) {} @@ -261,7 +261,7 @@ private: Operation *op; LocationAttr loc; - DictionaryAttr attrs; + SmallVector attrs; SmallVector operands; SmallVector successors; }; diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -564,8 +564,8 @@ } static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, - DictionaryAttr attrs) { - printer.printOptionalAttrDict(attrs.getValue()); + AttributeRange attrs) { + printer.printOptionalAttrDict(attrs); } static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, @@ -792,7 +792,7 @@ LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( MLIRContext *, Optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, + AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { if (operands[0].getType() != operands[1].getType()) { return emitOptionalError(location, "operand type mismatch ", @@ -805,7 +805,7 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( MLIRContext *context, Optional location, ValueShapeRange operands, - DictionaryAttr attributes, RegionRange regions, + AttributeRange attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { // Create return type consisting of the last element of the 
first operand. auto operandType = operands.front().getType(); diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2042,7 +2042,7 @@ let extraClassDeclaration = [{ static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands, - ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions, + ::mlir::AttributeRange attributes, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)}); return ::mlir::success(); @@ -2226,7 +2226,7 @@ let extraClassDeclaration = [{ static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *, ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands, - ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions, + ::mlir::AttributeRange attributes, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { inferredReturnTypes.assign({operands[0].getType()}); return ::mlir::success(); diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -160,9 +160,9 @@ for (int j = 0; j < e; ++j) { std::array values = {{fop.getArgument(i), fop.getArgument(j)}}; SmallVector inferredReturnTypes; - if (succeeded(OpTy::inferReturnTypes( - context, llvm::None, values, op->getAttrDictionary(), - op->getRegions(), inferredReturnTypes))) { + if (succeeded(OpTy::inferReturnTypes(context, llvm::None, values, + op->getAttrs(), op->getRegions(), + inferredReturnTypes))) { OperationState state(location, OpTy::getOperationName()); // TODO: Expand to regions. 
OpTy::build(b, state, values, op->getAttrs()); diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -66,7 +66,7 @@ # ODS: let builders = # ODS: (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, # ODS-NEXT: "ValueRange":$outputs, -# ODS-NEXT: CArg<"ArrayRef", "{}">:$attributes), +# ODS-NEXT: CArg<"AttributeRange", "{}">:$attributes), # ODS: $_state.addOperands(inputs); # ODS-NEXT: $_state.addOperands(outputs); diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -127,7 +127,7 @@ // DEF: odsState.addAttribute(aAttrAttrName(odsState.name), some-const-builder-call(odsBuilder, aAttr)); // DEF: void AOp::build( -// DEF: ::llvm::ArrayRef<::mlir::NamedAttribute> attributes +// DEF: ::mlir::AttributeRange attributes // DEF: odsState.addAttributes(attributes); // Test the above but with prefix. 
@@ -239,7 +239,7 @@ // DEF: odsState.addAttribute(getAAttrAttrName(odsState.name), some-const-builder-call(odsBuilder, aAttr)); // DEF: void AgetOp::build( -// DEF: ::llvm::ArrayRef<::mlir::NamedAttribute> attributes +// DEF: ::mlir::AttributeRange attributes // DEF: odsState.addAttributes(attributes); def SomeTypeAttr : TypeAttrBase<"SomeType", "some type attribute">; diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td --- a/mlir/test/mlir-tblgen/op-decl-and-defs.td +++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td @@ -55,7 +55,7 @@ // CHECK: class AOpAdaptor { // CHECK: public: -// CHECK: AOpAdaptor(::mlir::ValueRange values +// CHECK: AOpAdaptor(::mlir::MLIRContext *mlirContext // CHECK: ::mlir::ValueRange getODSOperands(unsigned index); // CHECK: ::mlir::Value getA(); // CHECK: ::mlir::ValueRange getB(); @@ -93,7 +93,7 @@ // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, int integer = 0); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::TypeRange s, ::mlir::Value a, ::mlir::ValueRange b, ::mlir::IntegerAttr attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount) // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::TypeRange s, ::mlir::Value a, ::mlir::ValueRange b, uint32_t attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount) -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes, unsigned numRegions) +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::mlir::AttributeRange attributes, unsigned numRegions) // CHECK: static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, 
::mlir::OperationState &result); // CHECK: void print(::mlir::OpAsmPrinter &p); // CHECK: ::mlir::LogicalResult verify(); @@ -105,7 +105,7 @@ // DEFS-LABEL: NS::AOp definitions -// DEFS: AOpAdaptor::AOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs, ::mlir::RegionRange regions) : odsOperands(values), odsAttrs(attrs), odsRegions(regions) +// DEFS: AOpAdaptor::AOpAdaptor(::mlir::MLIRContext *mlirContext, ::mlir::ValueRange values, ::mlir::AttributeRange attrs, ::mlir::RegionRange regions) : odsContext(mlirContext), odsOperands(values), odsAttrs(attrs), odsRegions(regions) // DEFS: ::mlir::RegionRange AOpAdaptor::getRegions() // DEFS: ::mlir::RegionRange AOpAdaptor::getSomeRegions() // DEFS-NEXT: return odsRegions.drop_front(1); @@ -125,8 +125,9 @@ } // CHECK-LABEL: AttrSizedOperandOpAdaptor( +// CHECK-SAME: ::mlir::MLIRContext *mlirContext // CHECK-SAME: ::mlir::ValueRange values -// CHECK-SAME: ::mlir::DictionaryAttr attrs +// CHECK-SAME: ::mlir::AttributeRange attrs // CHECK: ::mlir::ValueRange getA(); // CHECK: ::mlir::ValueRange getB(); // CHECK: ::mlir::Value getC(); @@ -198,7 +199,7 @@ // CHECK_LABEL: class NS_HCollectiveParamsOp : // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type b, ::mlir::Value a); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a); -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}) +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::mlir::AttributeRange attributes = {}) // Check suppression of "separate arg, separate result" build method for an op // with single variadic arg and single variadic result (since it will be @@ -210,7 +211,7 @@ 
// CHECK_LABEL: class NS_HCollectiveParamsSuppress0Op : // CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a); -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::mlir::AttributeRange attributes = {}); // Check suppression of "separate arg, collective result" build method for an op // with single variadic arg and non variadic result (since it will be @@ -222,7 +223,7 @@ // CHECK_LABEL: class NS_HCollectiveParamsSuppress1Op : // CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a); -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::mlir::AttributeRange attributes = {}); // Check suppression of "separate arg, collective result" build method for an op // with single variadic arg and > 1 variadic result (since it will be @@ -236,7 +237,7 @@ // CHECK_LABEL: class NS_HCollectiveParamsSuppress2Op : // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::TypeRange c, ::mlir::ValueRange a); // CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a); -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, 
::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::mlir::AttributeRange attributes = {}); // Check default value of `attributes` for the `genUseOperandAsResultTypeCollectiveParamBuilder` builder def NS_IOp : NS_Op<"op_with_same_operands_and_result_types_trait", [SameOperandsAndResultType]> { @@ -246,9 +247,9 @@ // CHECK_LABEL: class NS_IOp : // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b); -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::mlir::AttributeRange attributes = {}); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b); -// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::mlir::AttributeRange attributes = {}); // Check default value of `attributes` for the `genInferredTypeCollectiveParamBuilder` builder def NS_JOp : NS_Op<"op_with_InferTypeOpInterface_interface", [DeclareOpInterfaceMethods]> { @@ -259,8 +260,8 @@ // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, 
::mlir::Type r, ::mlir::Value a, ::mlir::Value b); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b); -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); -// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::mlir::AttributeRange attributes = {}); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::mlir::AttributeRange attributes = {}); // Test usage of OpTraitList getting flattened during emission. 
def NS_KOp : NS_Op<"k_op", [IsolatedFromAbove, diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -50,7 +50,7 @@ } // CHECK-LABEL: OpD definitions -// CHECK: void OpD::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) +// CHECK: void OpD::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::mlir::AttributeRange attributes) // CHECK: odsState.addTypes({attr.second.cast<::mlir::TypeAttr>().getValue()}); def OpE : NS_Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> { @@ -59,7 +59,7 @@ } // CHECK-LABEL: OpE definitions -// CHECK: void OpE::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) +// CHECK: void OpE::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::mlir::AttributeRange attributes) // CHECK: odsState.addTypes({attr.second.getType()}); def OpF : NS_Op<"one_variadic_result_op", []> { @@ -110,7 +110,7 @@ let results = (outs AnyTensor:$result); } -// CHECK-LABEL: OpK::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) +// CHECK-LABEL: OpK::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::mlir::AttributeRange attributes) // CHECK: odsState.addTypes({operands[0].getType()}); // Test with inferred shapes and interleaved with operands/attributes. 
diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td --- a/mlir/test/python/python_test_ops.td +++ b/mlir/test/python/python_test_ops.td @@ -41,7 +41,7 @@ let extraClassDeclaration = [{ static ::mlir::LogicalResult inferReturnTypes( ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, - ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, + ::mlir::ValueRange operands, ::mlir::AttributeRange attributes, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { ::mlir::Builder b(context); diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -457,7 +457,7 @@ let builders = [ OpBuilder< (ins "ValueRange":$inputs, "ValueRange":$outputs, - CArg<"ArrayRef", "{{}">:$attributes), + CArg<"AttributeRange", "{{}">:$attributes), [{{ $_state.addOperands(inputs); $_state.addOperands(outputs); @@ -481,7 +481,7 @@ OpBuilder< (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, "ValueRange":$outputs, - CArg<"ArrayRef", "{{}">:$attributes), + CArg<"AttributeRange", "{{}">:$attributes), [{{ $_state.addOperands(inputs); $_state.addOperands(outputs); @@ -500,7 +500,7 @@ }]>, OpBuilder< (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands, - CArg<"ArrayRef", "{{}">:$attributes), + CArg<"AttributeRange", "{{}">:$attributes), [{{ $_state.addOperands(operands); $_state.addAttributes(attributes); @@ -541,7 +541,7 @@ , OpBuilder< (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, "ValueRange":$outputs, {1}, - CArg<"ArrayRef", "{{}">:$attributes), + CArg<"AttributeRange", "{{}">:$attributes), [{{ $_state.addOperands(inputs); $_state.addOperands(outputs); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- 
a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -77,11 +77,13 @@ // // {0}: The name of the attribute specifying the segment sizes. const char *adapterSegmentSizeAttrInitCode = R"( - assert(odsAttrs && "missing segment size attribute for op"); + assert(!odsAttrs.empty() && "missing segment size attribute for op"); auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>(); )"; const char *opSegmentSizeAttrInitCode = R"( - auto sizeAttr = (*this)->getAttr({0}).cast<::mlir::DenseIntElementsAttr>(); + auto sizeAttr = (*this)->getOpAttributes().get({0}AttrName(), {0}AttrLeft(), + {0}AttrRight()) + .cast<::mlir::DenseIntElementsAttr>(); )"; const char *attrSizedSegmentValueRangeCalcCode = R"( const uint32_t *sizeAttrValueIt = &*sizeAttr.value_begin(); @@ -136,6 +138,28 @@ )"; +/// Attribute subrange get. +/// +/// {0}: attribute name. +const char *attrSubrangeGet = "(*this)->getOpAttributes().get({0}AttrName(), " + "{0}AttrLeft(), {0}AttrRight())"; +const char *attrSubrangeGetNamed = "(*this)->getOpAttributes().getNamed({0}" + "AttrName(), {0}AttrLeft(), {0}AttrRight())"; +const char *attrSubrangeSet = "(*this)->getOpAttributes().set({0}AttrName(), " + "{1}, {0}AttrLeft(), {0}AttrRight())"; +const char *attrSubrangeDefinitelySet = + "(*this)->getOpAttributes().definitelySet({0}AttrName(), {1}, " + "{0}AttrLeft(), {0}AttrRight())"; +const char *attrSubrangeErase = "(*this)->getOpAttributes().erase({0}AttrName()" + ", {0}AttrLeft(), {0}AttrRight())"; +const char *attrSubrangeGetOrSet = R"( + (*this)->getOpAttributes().getOrSet({0}AttrName(), [&] { + return {1}; + }, {0}AttrLeft(), {0}AttrRight()) +)"; + +const char *adaptorGetAttr = "odsAttrs.get(\"{0}\")"; + //===----------------------------------------------------------------------===// // Utility structs and functions //===----------------------------------------------------------------------===// @@ -155,7 +179,7 @@ // via getValueAsString. 
static inline bool hasStringAttribute(const Record &record, StringRef fieldName) { - auto valueInit = record.getValueInit(fieldName); + auto *valueInit = record.getValueInit(fieldName); return isa(valueInit); } @@ -163,8 +187,7 @@ const auto &operand = op.getOperand(index); if (!operand.name.empty()) return std::string(operand.name); - else - return std::string(formatv("{0}_{1}", generatedArgName, index)); + return std::string(formatv("{0}_{1}", generatedArgName, index)); } // Returns true if we can use unwrapped value for the given `attr` in builders. @@ -180,6 +203,11 @@ //===----------------------------------------------------------------------===// namespace { +struct AttrSubrange { + StringRef name; + unsigned left, right; +}; + // Helper class to emit a record into the given output stream. class OpEmitter { public: @@ -358,6 +386,9 @@ // The emitter containing all of the locally emitted verification functions. const StaticVerifierFunctionEmitter &staticVerifierEmitter; + + // Pre-computed attribute subranges. + SmallVector attrSubranges; }; } // end anonymous namespace @@ -370,13 +401,15 @@ // an operand (the generated function call returns an OperandRange); // - resultGet corresponds to the name of the function to get an result (the // generated function call returns a ValueRange); -static void populateSubstitutions(const Operator &op, const char *attrGet, - const char *operandGet, const char *resultGet, - FmtContext &ctx) { +static void populateSubstitutions(const Operator &op, bool opRequired, + const char *attrGet, const char *operandGet, + const char *resultGet, FmtContext &ctx) { // Populate substitutions for attributes and named operands. - for (const auto &namedAttr : op.getAttributes()) + for (const auto &namedAttr : op.getAttributes()) { ctx.addSubst(namedAttr.name, - formatv("{0}(\"{1}\")", attrGet, namedAttr.name)); + formatv(attrGet, opRequired ? 
op.getGetterName(namedAttr.name) + : namedAttr.name)); + } for (int i = 0, e = op.getNumOperands(); i < e; ++i) { auto &value = op.getOperand(i); if (value.name.empty()) @@ -412,7 +445,8 @@ static void genAttributeVerifier(const Operator &op, const char *attrGet, const Twine &emitErrorPrefix, bool emitVerificationRequiringOp, - FmtContext &ctx, OpMethodBody &body) { + FmtContext &ctx, OpMethodBody &body, + bool useOpGetter = false) { for (const auto &namedAttr : op.getAttributes()) { const auto &attr = namedAttr.attr; if (attr.isDerivedAttr()) @@ -442,8 +476,10 @@ !hasConditionToEmit) continue; - body << formatv(" {\n auto {0} = {1}(\"{2}\");\n", varName, attrGet, - attrName); + body << formatv(" {\n auto {0} = ", varName, attrGet, attrName) + << formatv(attrGet, + useOpGetter ? op.getGetterName(attrName) : attrName) + << ";\n"; if (!emitVerificationRequiringOp && !allowMissingAttr) { body << " if (!" << varName << ") return " << emitErrorPrefix @@ -472,11 +508,63 @@ } } +static SmallVector collectAttrNames(const Operator &op) { + // A map of attribute names (including implicit attributes) registered to the + // current operation, to whether the attribute is mandatory. + llvm::MapVector uniqueAttrNames; + + // Enumerate the attribute names of this op, assigning each a relative + // ordering. + for (const NamedAttribute &namedAttr : op.getAttributes()) { + const Attribute &attr = namedAttr.attr; + bool isMandatory = + !attr.isDerivedAttr() && !attr.isOptional() && !attr.hasDefaultValue(); + uniqueAttrNames.insert({namedAttr.name, isMandatory}); + } + // Include key attributes from several traits as implicitly registered. 
+ const char *operandSizes = "operand_segment_sizes"; + if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) + uniqueAttrNames.insert({operandSizes, true}); + const char *attrSizes = "result_segment_sizes"; + if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) + uniqueAttrNames.insert({attrSizes, true}); + + if (uniqueAttrNames.empty()) + return {}; + + // Sort the attribute names so that the search lower and upper bounds can + // be determined. + std::vector> attrNames = + uniqueAttrNames.takeVector(); + std::sort(attrNames.begin(), attrNames.end(), + [](auto &lhs, auto &rhs) { return lhs.first < rhs.first; }); + + // Determine the number of **mandatory** attributes to the left and right of + // each attribute. + SmallVector mandatoryLeft(attrNames.size(), 0); + unsigned numMandatory = 0; + for (auto it : llvm::enumerate(attrNames)) { + mandatoryLeft[it.index()] = numMandatory; + numMandatory += it.value().second; + } + + SmallVector attrSubranges; + attrSubranges.reserve(attrNames.size()); + for (auto it : llvm::zip(attrNames, mandatoryLeft)) { + attrSubranges.push_back(AttrSubrange{ + .name = std::get<0>(it).first, + .left = std::get<1>(it), + .right = numMandatory - std::get<1>(it) - std::get<0>(it).second}); + } + return attrSubranges; +} + OpEmitter::OpEmitter(const Operator &op, const StaticVerifierFunctionEmitter &staticVerifierEmitter) : def(op.getDef()), op(op), opClass(op.getCppClassName(), op.getExtraClassDeclaration()), - staticVerifierEmitter(staticVerifierEmitter) { + staticVerifierEmitter(staticVerifierEmitter), + attrSubranges(collectAttrNames(op)) { verifyCtx.withOp("(*this->getOperation())"); verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()"); @@ -534,25 +622,6 @@ #define ERROR_IF_PRUNED(M, N, O) errorIfPruned(__LINE__, M, N, O) void OpEmitter::genAttrNameGetters() { - // A map of attribute names (including implicit attributes) registered to the - // current operation, to the relative order in which they were 
registered. - llvm::MapVector attributeNames; - - // Enumerate the attribute names of this op, assigning each a relative - // ordering. - auto addAttrName = [&](StringRef name) { - unsigned index = attributeNames.size(); - attributeNames.insert({name, index}); - }; - for (const NamedAttribute &namedAttr : op.getAttributes()) - addAttrName(namedAttr.name); - // Include key attributes from several traits as implicitly registered. - std::string operandSizes = "operand_segment_sizes"; - if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) - addAttrName(operandSizes); - std::string attrSizes = "result_segment_sizes"; - if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) - addAttrName(attrSizes); // Emit the getAttributeNames method. { @@ -561,19 +630,18 @@ OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Inline)); ERROR_IF_PRUNED(method, "getAttributeNames", op); auto &body = method->body(); - if (attributeNames.empty()) { + if (attrSubranges.empty()) { body << " return {};"; } else { body << " static ::llvm::StringRef attrNames[] = {"; - llvm::interleaveComma(llvm::make_first_range(attributeNames), body, - [&](StringRef attrName) { - body << "::llvm::StringRef(\"" << attrName - << "\")"; - }); + llvm::interleaveComma( + attrSubranges, body, [&](const AttrSubrange attrName) { + body << "::llvm::StringRef(\"" << attrName.name << "\")"; + }); body << "};\n return ::llvm::makeArrayRef(attrNames);"; } } - if (attributeNames.empty()) + if (attrSubranges.empty()) return; // Emit the getAttributeNameForIndex methods. 
@@ -593,7 +661,7 @@ OpMethod::MP_Static), "::mlir::OperationName name, unsigned index"); ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op); - method->body() << "assert(index < " << attributeNames.size() + method->body() << "assert(index < " << attrSubranges.size() << " && \"invalid attribute index\");\n" " return name.getAbstractOperation()" "->getAttributeNames()[index];"; @@ -602,8 +670,8 @@ // Generate the AttrName methods, that expose the attribute names to // users. const char *attrNameMethodBody = " return getAttributeNameForIndex({0});"; - for (const std::pair &attrIt : attributeNames) { - for (StringRef name : op.getGetterNames(attrIt.first)) { + for (auto attrIt : llvm::enumerate(attrSubranges)) { + for (StringRef name : op.getGetterNames(attrIt.value().name)) { std::string methodName = (name + "AttrName").str(); // Generate the non-static variant. @@ -613,7 +681,7 @@ OpMethod::Property(OpMethod::MP_Inline)); ERROR_IF_PRUNED(method, methodName, op); method->body() - << llvm::formatv(attrNameMethodBody, attrIt.second).str(); + << llvm::formatv(attrNameMethodBody, attrIt.index()).str(); } // Generate the static variant. 
@@ -624,11 +692,33 @@ "::mlir::OperationName", "name"); ERROR_IF_PRUNED(method, methodName, op); method->body() << llvm::formatv(attrNameMethodBody, - "name, " + Twine(attrIt.second)) + "name, " + Twine(attrIt.index())) .str(); } } } + + // AttrLeft, AttrRight + for (auto &subrange : attrSubranges) { + for (StringRef name : op.getGetterNames(subrange.name)) { + { + std::string methodName = (name + "AttrLeft").str(); + auto *method = opClass.addMethodAndPrune( + "unsigned", methodName, + OpMethod::Property(OpMethod::MP_Inline | OpMethod::MP_Static)); + ERROR_IF_PRUNED(method, methodName, op); + method->body() << " return " << subrange.left << ";"; + } + { + std::string methodName = (name + "AttrRight").str(); + auto *method = opClass.addMethodAndPrune( + "unsigned", methodName, + OpMethod::Property(OpMethod::MP_Inline | OpMethod::MP_Static)); + ERROR_IF_PRUNED(method, methodName, op); + method->body() << " return " << subrange.right << ";"; + } + } + } } void OpEmitter::genAttrGetters() { @@ -646,19 +736,15 @@ auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name); ERROR_IF_PRUNED(method, name, op); auto &body = method->body(); - body << " auto attr = " << name << "Attr();\n"; + body << " auto attr = "; if (attr.hasDefaultValue()) { - // Returns the default value if not set. - // TODO: this is inefficient, we are recreating the attribute for every - // call. This should be set instead. 
- std::string defaultValue = std::string( + body << formatv( + attrSubrangeGetOrSet, name, tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); - body << " if (!attr)\n return " - << tgfmt(attr.getConvertFromStorageCall(), - &fctx.withSelf(defaultValue)) - << ";\n"; + } else { + body << name << "Attr()"; } - body << " return " + body << ";\n return " << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr")) << ";\n"; }; @@ -672,7 +758,7 @@ if (!method) return; auto &body = method->body(); - body << " return (*this)->getAttr(" << name << "AttrName()).template "; + body << " return " << formatv(attrSubrangeGet, name) << ".template "; if (attr.isOptional() || attr.hasDefaultValue()) body << "dyn_cast_or_null<"; else @@ -768,17 +854,24 @@ Attribute attr) { auto *method = opClass.addMethodAndPrune( "void", (setterName + "Attr").str(), attr.getStorageType(), "attr"); - if (method) - method->body() << " (*this)->setAttr(" << getterName - << "AttrName(), attr);"; + if (method) { + method->body() << " " + << formatv(attr.isOptional() || attr.hasDefaultValue() + ? 
attrSubrangeSet + : attrSubrangeDefinitelySet, + getterName, "attr") + << ";\n"; + } }; for (const NamedAttribute &namedAttr : op.getAttributes()) { - if (!namedAttr.attr.isDerivedAttr()) - for (auto names : llvm::zip(op.getSetterNames(namedAttr.name), - op.getGetterNames(namedAttr.name))) - emitAttrWithStorageType(std::get<0>(names), std::get<1>(names), - namedAttr.attr); + if (namedAttr.attr.isDerivedAttr()) + continue; + for (auto names : llvm::zip(op.getSetterNames(namedAttr.name), + op.getGetterNames(namedAttr.name))) { + emitAttrWithStorageType(std::get<0>(names), std::get<1>(names), + namedAttr.attr); + } } } @@ -792,13 +885,15 @@ "::mlir::Attribute", ("remove" + upperInitial + suffix + "Attr").str()); if (!method) return; - method->body() << " return (*this)->removeAttr(" << op.getGetterName(name) - << "AttrName());"; + method->body() << " return " + << formatv(attrSubrangeErase, op.getGetterName(name)) + << ";\n"; }; - for (const NamedAttribute &namedAttr : op.getAttributes()) + for (const NamedAttribute &namedAttr : op.getAttributes()) { if (namedAttr.attr.isOptional()) emitRemoveAttr(namedAttr.name); + } } // Generates the code to compute the start and end index of an operand or result @@ -935,8 +1030,9 @@ // array. 
std::string attrSizeInitCode; if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - std::string attr = op.getGetterName("operand_segment_sizes") + "AttrName()"; - attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, attr).str(); + attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, + op.getGetterName("operand_segment_sizes")) + .str(); } generateNamedOperandGetters( @@ -967,10 +1063,12 @@ << " auto mutableRange = " "::mlir::MutableOperandRange(getOperation(), " "range.first, range.second"; - if (attrSizedOperands) - body << ", ::mlir::MutableOperandRange::OperandSegment(" << i - << "u, *getOperation()->getAttrDictionary().getNamed(" - << op.getGetterName("operand_segment_sizes") << "AttrName()))"; + if (attrSizedOperands) { + body << ", ::mlir::MutableOperandRange::OperandSegment(" << i << "u, *" + << formatv(attrSubrangeGetNamed, + op.getGetterName("operand_segment_sizes")) + << ")"; + } body << ");\n"; // If this operand is a nested variadic, we split the range into a @@ -979,7 +1077,7 @@ if (operand.isVariadicOfVariadic()) { // body << " return " - "mutableRange.split(*(*this)->getAttrDictionary().getNamed(" + "mutableRange.split(*(*this)->getNamedAttr(" << op.getGetterName( operand.constraint.getVariadicOfVariadicSegmentSizeAttr()) << "AttrName()));\n"; @@ -1023,8 +1121,9 @@ // Build the initializer string for the result segment size attribute. 
std::string attrSizeInitCode; if (attrSizedResults) { - std::string attr = op.getGetterName("result_segment_sizes") + "AttrName()"; - attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, attr).str(); + attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, + op.getGetterName("result_segment_sizes")) + .str(); } generateValueRangeStartAndEnd( @@ -1176,8 +1275,7 @@ ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes; if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(), {1}.location, {1}.operands, - {1}.attributes.getDictionary({1}.getContext()), - /*regions=*/{{}, inferredReturnTypes))) + {1}.attributes, /*regions=*/{{}, inferredReturnTypes))) {1}.addTypes(inferredReturnTypes); else ::llvm::report_fatal_error("Failed to infer result type(s).");)", @@ -1236,8 +1334,8 @@ paramList.emplace_back("::mlir::ValueRange", "operands"); // Provide default value for `attributes` when its the last parameter StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}"; - paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", - "attributes", attributesDefaultValue); + paramList.emplace_back("::mlir::AttributeRange", "attributes", + attributesDefaultValue); if (op.getNumVariadicRegions()) paramList.emplace_back("unsigned", "numRegions"); @@ -1274,8 +1372,7 @@ paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); paramList.emplace_back("::mlir::OperationState &", builderOpState); paramList.emplace_back("::mlir::ValueRange", "operands"); - paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", - "attributes", "{}"); + paramList.emplace_back("::mlir::AttributeRange", "attributes", "{}"); auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, std::move(paramList)); // If the builder is redundant, skip generating the method @@ -1313,8 +1410,7 @@ ::mlir::SmallVector<::mlir::Type, 2> inferredReturnTypes; if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(), {1}.location, operands, - 
{1}.attributes.getDictionary({1}.getContext()), - /*regions=*/{{}, inferredReturnTypes))) {{)", + {1}.attributes, /*regions=*/{{}, inferredReturnTypes))) {{)", opClass.getClassName(), builderOpState); if (numVariadicResults == 0 || numNonVariadicResults != 0) body << " assert(inferredReturnTypes.size()" @@ -1362,8 +1458,7 @@ paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); paramList.emplace_back("::mlir::OperationState &", builderOpState); paramList.emplace_back("::mlir::ValueRange", "operands"); - paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", - "attributes", "{}"); + paramList.emplace_back("::mlir::AttributeRange", "attributes", "{}"); auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, std::move(paramList)); // If the builder is redundant, skip generating the method @@ -1490,8 +1585,8 @@ paramList.emplace_back("::mlir::ValueRange", "operands"); // Provide default value for `attributes` when its the last parameter StringRef attributesDefaultValue = op.getNumVariadicRegions() ? 
"" : "{}"; - paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", - "attributes", attributesDefaultValue); + paramList.emplace_back("::mlir::AttributeRange", "attributes", + attributesDefaultValue); if (op.getNumVariadicRegions()) paramList.emplace_back("unsigned", "numRegions"); @@ -2032,9 +2127,8 @@ << "].getType()"; return body << "attributes[" << arg.operandOrAttributeIndex() << "].getType()"; - } else { - return body << tgfmt(*type.getType().getBuilderCall(), &fctx); } + return body << tgfmt(*type.getType().getBuilderCall(), &fctx); }; for (int i = 0, e = op.getNumResults(); i != e; ++i) { @@ -2071,7 +2165,7 @@ if (hasStringAttribute(def, "assemblyFormat")) return; - auto valueInit = def.getValueInit("printer"); + auto *valueInit = def.getValueInit("printer"); StringInit *stringInit = dyn_cast(valueInit); if (!stringInit) return; @@ -2085,22 +2179,74 @@ method->body() << " " << tgfmt(printer, &fctx); } +static void genOpOrAdaptorVerifier(OpMethodBody &body, const Operator &op, + bool opRequired, FmtContext &verifyCtx) { + const char *const checkAttrSizedValueSegmentsCode = R"( + { + auto sizeAttr = {3}.cast<::mlir::DenseIntElementsAttr>(); + auto numElements = sizeAttr.getType().cast<::mlir::ShapedType>() + .getNumElements(); + if (numElements != {1}) + return {4}"'{0}' attribute for specifying {2} segments must have {1} " + "elements, but got ") << numElements; + } + )"; + + std::string emitError = + opRequired ? "emitOpError(" + : ("emitError(loc, \"'" + op.getOperationName() + "' op \""); + + // Use subrange getter on the Operation if inside of one. + const char *const attrGetter = opRequired ? attrSubrangeGet : adaptorGetAttr; + const auto checkSegmentAttr = [&](StringRef attrName, unsigned count, + StringRef errNoun) { + body << formatv( + checkAttrSizedValueSegmentsCode, attrName, count, errNoun, + formatv(attrGetter, opRequired ? 
op.getGetterName(attrName) : attrName), + emitError); + }; + + // Verify a few traits first so that we can use + // getODSOperands()/getODSResults() in the rest of the verifier. + for (auto &trait : op.getTraits()) { + if (auto *t = dyn_cast(&trait)) { + if (t->getFullyQualifiedTraitName() == + "::mlir::OpTrait::AttrSizedOperandSegments") { + checkSegmentAttr("operand_segment_sizes", op.getNumOperands(), + "operand"); + } else if (t->getFullyQualifiedTraitName() == + "::mlir::OpTrait::AttrSizedResultSegments") { + checkSegmentAttr("result_segment_sizes", op.getNumResults(), "result"); + } + } + } + + populateSubstitutions(op, opRequired, attrGetter, "getODSOperands", + "", verifyCtx); + genAttributeVerifier(op, attrGetter, emitError, + /*emitVerificationRequiringOp=*/false, verifyCtx, body, + /*useOpGetter=*/opRequired); +} + void OpEmitter::genVerifier() { auto *method = opClass.addMethodAndPrune("::mlir::LogicalResult", "verify"); ERROR_IF_PRUNED(method, "verify", op); auto &body = method->body(); - body << " if (::mlir::failed(" << op.getAdaptorName() - << "(*this).verify((*this)->getLoc()))) " - << "return ::mlir::failure();\n"; + // body << " if (::mlir::failed(" << op.getAdaptorName() + // << "(*this).verify((*this)->getLoc()))) " + // << "return ::mlir::failure();\n"; + genOpOrAdaptorVerifier(body, op, /*opRequired=*/true, verifyCtx); auto *valueInit = def.getValueInit("verifier"); StringInit *stringInit = dyn_cast(valueInit); bool hasCustomVerify = stringInit && !stringInit->getValue().empty(); - populateSubstitutions(op, "(*this)->getAttr", "this->getODSOperands", - "this->getODSResults", verifyCtx); + populateSubstitutions(op, /*opRequired=*/true, attrSubrangeGet, + "this->getODSOperands", "this->getODSResults", + verifyCtx); - genAttributeVerifier(op, "(*this)->getAttr", "emitOpError(", - /*emitVerificationRequiringOp=*/true, verifyCtx, body); + genAttributeVerifier(op, attrSubrangeGet, "emitOpError(", + /*emitVerificationRequiringOp=*/true, verifyCtx, 
body, + /*useOpGetter=*/true); genOperandResultVerifier(body, op.getOperands(), "operand"); genOperandResultVerifier(body, op.getResults(), "result"); @@ -2336,9 +2482,9 @@ // Add the native and interface traits. for (const auto &trait : op.getTraits()) { - if (auto opTrait = dyn_cast(&trait)) + if (auto *opTrait = dyn_cast(&trait)) opClass.addTrait(opTrait->getFullyQualifiedTraitName()); - else if (auto opTrait = dyn_cast(&trait)) + else if (auto *opTrait = dyn_cast(&trait)) opClass.addTrait(opTrait->getFullyQualifiedTraitName()); } } @@ -2410,21 +2556,46 @@ }; } // end namespace +/// Code template to get an attribute in an OpAdaptor. +/// +/// {0}: Attribute storage type. +/// {1}: Attribute name. +/// {2}: `cast` for required attributes, `dyn_cast_or_null` for optional or +/// default-valued attributes. +static const char *const opAdaptorGetAttr = R"( + {0} attr = odsAttrs.get("{1}").{2}<{0}>(); +)"; + +/// Code template to materialize an attribute's default value in an OpAdaptor, +/// if a value for the attribute was not found. +/// +/// TODO: This is inefficient because default values are materialized each call. +/// However, OpAdaptor is an immutable view on an Operation-like. 
+/// +/// {0}: Code snippet to construct the attribute's default value; +static const char *const opAdaptorConstructDefaultAttr = R"( + if (!attr) + attr = {0}; +)"; + OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op) : op(op), adaptor(op.getAdaptorName()) { + adaptor.newField("::mlir::MLIRContext *", "odsContext"); adaptor.newField("::mlir::ValueRange", "odsOperands"); - adaptor.newField("::mlir::DictionaryAttr", "odsAttrs"); + adaptor.newField("::mlir::AttributeRange", "odsAttrs"); adaptor.newField("::mlir::RegionRange", "odsRegions"); const auto *attrSizedOperands = op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"); { SmallVector paramList; + paramList.emplace_back("::mlir::MLIRContext *", "mlirContext"); paramList.emplace_back("::mlir::ValueRange", "values"); - paramList.emplace_back("::mlir::DictionaryAttr", "attrs", - attrSizedOperands ? "" : "nullptr"); + paramList.emplace_back("::mlir::AttributeRange", "attrs", + attrSizedOperands ? "" : "{}"); paramList.emplace_back("::mlir::RegionRange", "regions", "{}"); auto *constructor = adaptor.addConstructorAndPrune(std::move(paramList)); + constructor->addMemberInitializer("odsContext", "mlirContext"); constructor->addMemberInitializer("odsOperands", "values"); constructor->addMemberInitializer("odsAttrs", "attrs"); constructor->addMemberInitializer("odsRegions", "regions"); @@ -2433,8 +2604,9 @@ { auto *constructor = adaptor.addConstructorAndPrune( llvm::formatv("{0}&", op.getCppClassName()).str(), "op"); + constructor->addMemberInitializer("odsContext", "op->getContext()"); constructor->addMemberInitializer("odsOperands", "op->getOperands()"); - constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()"); + constructor->addMemberInitializer("odsAttrs", "op->getAttrs()"); constructor->addMemberInitializer("odsRegions", "op->getRegions()"); } @@ -2453,35 +2625,28 @@ /*getOperandCallPattern=*/"odsOperands[{0}]"); FmtContext fctx; - 
fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())"); + fctx.withBuilder("::mlir::Builder(odsContext)"); auto emitAttr = [&](StringRef name, Attribute attr) { auto *method = adaptor.addMethodAndPrune(attr.getStorageType(), name); ERROR_IF_PRUNED(method, "Adaptor::" + name, op); auto &body = method->body(); - body << " assert(odsAttrs && \"no attributes when constructing adapter\");" - << "\n " << attr.getStorageType() << " attr = " - << "odsAttrs.get(\"" << name << "\")."; - if (attr.hasDefaultValue() || attr.isOptional()) - body << "dyn_cast_or_null<"; - else - body << "cast<"; - body << attr.getStorageType() << ">();\n"; + body << formatv(opAdaptorGetAttr, attr.getStorageType(), name, + attr.hasDefaultValue() || attr.isOptional() + ? "dyn_cast_or_null" + : "cast"); if (attr.hasDefaultValue()) { - // Use the default value if attribute is not set. - // TODO: this is inefficient, we are recreating the attribute for every - // call. This should be set instead. - std::string defaultValue = std::string( + body << formatv( + opAdaptorConstructDefaultAttr, tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); - body << " if (!attr)\n attr = " << defaultValue << ";\n"; } body << " return attr;\n"; }; { auto *m = - adaptor.addMethodAndPrune("::mlir::DictionaryAttr", "getAttributes"); + adaptor.addMethodAndPrune("::mlir::AttributeRange", "getAttributes"); ERROR_IF_PRUNED(m, "Adaptor::getAttributes", op); m->body() << " return odsAttrs;"; } @@ -2489,7 +2654,7 @@ const auto &name = namedAttr.name; const auto &attr = namedAttr.attr; if (!attr.isDerivedAttr()) { - for (auto emitName : op.getGetterNames(name)) + for (auto &emitName : op.getGetterNames(name)) emitAttr(emitName, attr); } } @@ -2528,44 +2693,10 @@ auto *method = adaptor.addMethodAndPrune("::mlir::LogicalResult", "verify", "::mlir::Location", "loc"); ERROR_IF_PRUNED(method, "verify", op); - auto &body = method->body(); - - const char *checkAttrSizedValueSegmentsCode = R"( - { - auto sizeAttr = 
odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>(); - auto numElements = sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements(); - if (numElements != {1}) - return emitError(loc, "'{0}' attribute for specifying {2} segments " - "must have {1} elements, but got ") << numElements; - } - )"; - - // Verify a few traits first so that we can use - // getODSOperands()/getODSResults() in the rest of the verifier. - for (auto &trait : op.getTraits()) { - if (auto *t = dyn_cast(&trait)) { - if (t->getFullyQualifiedTraitName() == - "::mlir::OpTrait::AttrSizedOperandSegments") { - body << formatv(checkAttrSizedValueSegmentsCode, - "operand_segment_sizes", op.getNumOperands(), - "operand"); - } else if (t->getFullyQualifiedTraitName() == - "::mlir::OpTrait::AttrSizedResultSegments") { - body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes", - op.getNumResults(), "result"); - } - } - } - FmtContext verifyCtx; - populateSubstitutions(op, "odsAttrs.get", "getODSOperands", - "", verifyCtx); - genAttributeVerifier(op, "odsAttrs.get", - Twine("emitError(loc, \"'") + op.getOperationName() + - "' op \"", - /*emitVerificationRequiringOp*/ false, verifyCtx, body); - - body << " return ::mlir::success();"; + genOpOrAdaptorVerifier(method->body(), op, + /*opRequired=*/false, verifyCtx); + method->body() << " return ::mlir::success();\n"; } void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) { diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -693,8 +693,7 @@ ::llvm::SmallVector<::mlir::Type> inferredReturnTypes; if (::mlir::failed({0}::inferReturnTypes(parser.getContext(), result.location, result.operands, - result.attributes.getDictionary(parser.getContext()), - result.regions, inferredReturnTypes))) + result.attributes, result.regions, inferredReturnTypes))) return ::mlir::failure(); 
result.addTypes(inferredReturnTypes); )"; @@ -1654,7 +1653,7 @@ { bool printTerminator = true; if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{ - printTerminator = !term->getAttrDictionary().empty() || + printTerminator = !term->getAttrs().empty() || term->getNumOperands() != 0 || term->getNumResults() != 0; } @@ -1742,7 +1741,7 @@ body << op.getGetterName(attr->getVar()->name) << "Attr()"; } else if (isa(element)) { - body << "getOperation()->getAttrDictionary()"; + body << "getOperation()->getAttrs()"; } else if (auto *operand = dyn_cast(element)) { body << op.getGetterName(operand->getVar()->name) << "()";