diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -438,12 +438,12 @@ // This is a strided getElementPtr variant that linearizes subscripts as: // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`. Value getStridedElementPtr(Location loc, Type elementTypePtr, - Value descriptor, ArrayRef<Value> indices, + Value descriptor, ValueRange indices, ArrayRef<int64_t> strides, int64_t offset, ConversionPatternRewriter &rewriter) const; Value getDataPtr(Location loc, MemRefType type, Value memRefDesc, - ArrayRef<Value> indices, ConversionPatternRewriter &rewriter, + ValueRange indices, ConversionPatternRewriter &rewriter, llvm::Module &module) const; protected: diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h @@ -124,7 +124,7 @@ // with AffineMap that has static strides. Extend to handle dynamic strides. spirv::AccessChainOp getElementPtr(SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, - ArrayRef<Value> indices, Location loc, + ValueRange indices, Location loc, OpBuilder &builder); /// Sets the InterfaceVarABIAttr and EntryPointABIAttr for a function and its diff --git a/mlir/include/mlir/TableGen/OpClass.h b/mlir/include/mlir/TableGen/OpClass.h --- a/mlir/include/mlir/TableGen/OpClass.h +++ b/mlir/include/mlir/TableGen/OpClass.h @@ -86,6 +86,7 @@ OpMethod(StringRef retType, StringRef name, StringRef params, Property property, bool declOnly); + virtual ~OpMethod() = default; OpMethodBody &body(); @@ -96,13 +97,13 @@ bool isPrivate() const; // Writes the method as a declaration to the given `os`.
- void writeDeclTo(raw_ostream &os) const; + virtual void writeDeclTo(raw_ostream &os) const; // Writes the method as a definition to the given `os`. `namePrefix` is the // prefix to be prepended to the method name (typically namespaces for // qualifying the method definition). - void writeDefTo(raw_ostream &os, StringRef namePrefix) const; + virtual void writeDefTo(raw_ostream &os, StringRef namePrefix) const; -private: +protected: Property properties; // Whether this method only contains a declaration. bool isDeclOnly; @@ -110,6 +111,26 @@ OpMethodBody methodBody; }; +// Class for holding an op's constructor method for C++ code emission. +class OpConstructor : public OpMethod { +public: + OpConstructor(StringRef retType, StringRef name, StringRef params, + Property property, bool declOnly) + : OpMethod(retType, name, params, property, declOnly) {} + + // Appends `name(value)` to this constructor's member initializer list. + void addMemberInitializer(StringRef name, StringRef value); + + // Writes the constructor as a definition, including its member initializer + // list, to the given `os`. `namePrefix` is the prefix to be prepended to the + // method name (typically namespaces for qualifying the method definition). + void writeDefTo(raw_ostream &os, StringRef namePrefix) const override; + +private: + // Member initializer list, pre-formatted as " : a(x), b(y)"; empty if none. + std::string memberInitializers; +}; + // A class used to emit C++ classes from Tablegen. Contains a list of public // methods and a list of private fields to be emitted. class Class { @@ -121,7 +142,7 @@ OpMethod::Property = OpMethod::MP_None, bool declOnly = false); - OpMethod &newConstructor(StringRef params = "", bool declOnly = false); + OpConstructor &newConstructor(StringRef params = "", bool declOnly = false); // Creates a new field in this class.
void newField(StringRef type, StringRef name, StringRef defaultValue = ""); @@ -136,6 +157,7 @@ protected: std::string className; + SmallVector<OpConstructor, 2> constructors; SmallVector<OpMethod, 8> methods; SmallVector<std::string, 4> fields; }; diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -58,6 +58,9 @@ // Returns this op's C++ class name prefixed with namespaces. std::string getQualCppClassName() const; + // Returns the name of this op's operand adaptor C++ class. + std::string getAdaptorName() const; + /// A class used to represent the decorators of an operator variable, i.e. /// argument or result. struct VariableDecorator { diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -795,8 +795,8 @@ } Value ConvertToLLVMPattern::getStridedElementPtr( - Location loc, Type elementTypePtr, Value descriptor, - ArrayRef<Value> indices, ArrayRef<int64_t> strides, int64_t offset, + Location loc, Type elementTypePtr, Value descriptor, ValueRange indices, + ArrayRef<int64_t> strides, int64_t offset, ConversionPatternRewriter &rewriter) const { MemRefDescriptor memRefDescriptor(descriptor); @@ -818,8 +818,7 @@ } Value ConvertToLLVMPattern::getDataPtr(Location loc, MemRefType type, - Value memRefDesc, - ArrayRef<Value> indices, + Value memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter, llvm::Module &module) const { LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType(); @@ -2602,7 +2601,7 @@ // Build and return the value for the idx^th shape dimension, either by // returning the constant shape dimension or counting the proper dynamic size.
Value getSize(ConversionPatternRewriter &rewriter, Location loc, - ArrayRef<int64_t> shape, ArrayRef<Value> dynamicSizes, + ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx) const { assert(idx < shape.size()); if (!ShapedType::isDynamic(shape[idx])) diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -579,7 +579,7 @@ spirv::AccessChainOp mlir::spirv::getElementPtr( SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, - ArrayRef<Value> indices, Location loc, OpBuilder &builder) { + ValueRange indices, Location loc, OpBuilder &builder) { // Get base and offset of the MemRefType and verify they are static. int64_t offset; @@ -591,6 +591,7 @@ } auto indexType = typeConverter.getIndexType(builder.getContext()); + SmallVector<Value, 2> linearizedIndices; // Add a '0' at the start to index into the struct. auto zero = spirv::ConstantOp::getZero(indexType, loc, builder); @@ -606,7 +607,7 @@ loc, indexType, IntegerAttr::get(indexType, offset)); assert(indices.size() == strides.size() && "must provide indices for all dimensions"); - for (auto index : enumerate(indices)) { + for (auto index : llvm::enumerate(indices)) { Value strideVal = builder.create<spirv::ConstantOp>( loc, indexType, IntegerAttr::get(indexType, strides[index.index()])); Value update = diff --git a/mlir/lib/TableGen/OpClass.cpp b/mlir/lib/TableGen/OpClass.cpp --- a/mlir/lib/TableGen/OpClass.cpp +++ b/mlir/lib/TableGen/OpClass.cpp @@ -120,6 +120,27 @@ } //===----------------------------------------------------------------------===// +// OpConstructor definitions +//===----------------------------------------------------------------------===// + +void mlir::tblgen::OpConstructor::addMemberInitializer(StringRef name, + StringRef value) { + memberInitializers.append(std::string(llvm::formatv( + "{0}{1}({2})", memberInitializers.empty() ?
" : " : ", ", name, value))); +} + +void mlir::tblgen::OpConstructor::writeDefTo(raw_ostream &os, + StringRef namePrefix) const { + if (isDeclOnly) + return; + + methodSignature.writeDefTo(os, namePrefix); + os << memberInitializers << " {\n"; + methodBody.writeTo(os); + os << "}"; +} + +//===----------------------------------------------------------------------===// // Class definitions //===----------------------------------------------------------------------===// @@ -133,10 +154,11 @@ return methods.back(); } -tblgen::OpMethod &tblgen::Class::newConstructor(StringRef params, - bool declOnly) { - return newMethod("", getClassName(), params, OpMethod::MP_Constructor, - declOnly); +tblgen::OpConstructor &tblgen::Class::newConstructor(StringRef params, + bool declOnly) { + constructors.emplace_back("", getClassName(), params, + OpMethod::MP_Constructor, declOnly); + return constructors.back(); } void tblgen::Class::newField(StringRef type, StringRef name, @@ -152,7 +174,8 @@ bool hasPrivateMethod = false; os << "class " << className << " {\n"; os << "public:\n"; - for (const auto &method : methods) { + for (const auto &method : + llvm::concat<const OpMethod>(constructors, methods)) { if (!method.isPrivate()) { method.writeDeclTo(os); os << '\n'; @@ -163,7 +186,8 @@ os << '\n'; os << "private:\n"; if (hasPrivateMethod) { - for (const auto &method : methods) { + for (const auto &method : + llvm::concat<const OpMethod>(constructors, methods)) { if (method.isPrivate()) { method.writeDeclTo(os); os << '\n'; @@ -177,7 +201,8 @@ } void tblgen::Class::writeDefTo(raw_ostream &os) const { - for (const auto &method : methods) { + for (const auto &method : + llvm::concat<const OpMethod>(constructors, methods)) { method.writeDefTo(os, className); os << "\n\n"; } diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -59,6 +59,10 @@ return std::string(llvm::formatv("{0}.{1}", prefix, opName)); } +std::string
tblgen::Operator::getAdaptorName() const { + return std::string(llvm::formatv("{0}OperandAdaptor", getCppClassName())); +} + StringRef tblgen::Operator::getDialectName() const { return dialect.getName(); } StringRef tblgen::Operator::getCppClassName() const { return cppClassName; } diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -49,14 +49,14 @@ // CHECK: class AOpOperandAdaptor { // CHECK: public: -// CHECK: AOpOperandAdaptor(ArrayRef<Value> values -// CHECK: ArrayRef<Value> getODSOperands(unsigned index); +// CHECK: AOpOperandAdaptor(ValueRange values +// CHECK: ValueRange getODSOperands(unsigned index); // CHECK: Value a(); -// CHECK: ArrayRef<Value> b(); +// CHECK: ValueRange b(); // CHECK: IntegerAttr attr1(); // CHECL: FloatAttr attr2(); // CHECK: private: -// CHECK: ArrayRef<Value> odsOperands; +// CHECK: ValueRange odsOperands; // CHECK: }; // CHECK: class AOp : public Op<AOp, OpTrait::AtLeastNRegions<1>::Impl, OpTrait::AtLeastNResults<1>::Impl, OpTrait::ZeroSuccessor, OpTrait::AtLeastNOperands<1>::Impl, OpTrait::IsIsolatedFromAbove @@ -106,12 +106,12 @@ } // CHECK-LABEL: AttrSizedOperandOpOperandAdaptor( -// CHECK-SAME: ArrayRef<Value> values +// CHECK-SAME: ValueRange values // CHECK-SAME: DictionaryAttr attrs -// CHECK: ArrayRef<Value> a(); -// CHECK: ArrayRef<Value> b(); +// CHECK: ValueRange a(); +// CHECK: ValueRange b(); // CHECK: Value c(); -// CHECK: ArrayRef<Value> d(); +// CHECK: ValueRange d(); // CHECK: DenseIntElementsAttr operand_segment_sizes(); // Check op trait for different number of operands diff --git a/mlir/test/mlir-tblgen/op-operand.td b/mlir/test/mlir-tblgen/op-operand.td --- a/mlir/test/mlir-tblgen/op-operand.td +++ b/mlir/test/mlir-tblgen/op-operand.td @@ -15,7 +15,7 @@ // CHECK-LABEL: OpA definitions // CHECK: OpAOperandAdaptor::OpAOperandAdaptor -// CHECK-NEXT: odsOperands = values +// CHECK-SAME: odsOperands(values), odsAttrs(attrs) // CHECK: void OpA::build // CHECK: Value input @@ -39,13 +39,13 @@ let
arguments = (ins Variadic<AnyTensor>:$input1, AnyTensor:$input2, Variadic<AnyTensor>:$input3); } -// CHECK-LABEL: ArrayRef<Value> OpDOperandAdaptor::input1 +// CHECK-LABEL: ValueRange OpDOperandAdaptor::input1 // CHECK-NEXT: return getODSOperands(0); // CHECK-LABEL: Value OpDOperandAdaptor::input2 // CHECK-NEXT: return *getODSOperands(1).begin(); -// CHECK-LABEL: ArrayRef<Value> OpDOperandAdaptor::input3 +// CHECK-LABEL: ValueRange OpDOperandAdaptor::input3 // CHECK-NEXT: return getODSOperands(2); // CHECK-LABEL: Operation::operand_range OpD::input1 diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1890,27 +1890,38 @@ private: explicit OpOperandAdaptorEmitter(const Operator &op); - Class adapterClass; + Class adaptor; }; } // end namespace OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op) - : adapterClass(op.getCppClassName().str() + "OperandAdaptor") { - adapterClass.newField("ArrayRef<Value>", "odsOperands"); - adapterClass.newField("DictionaryAttr", "odsAttrs"); + : adaptor(op.getAdaptorName()) { + adaptor.newField("ValueRange", "odsOperands"); + adaptor.newField("DictionaryAttr", "odsAttrs"); const auto *attrSizedOperands = op.getTrait("OpTrait::AttrSizedOperandSegments"); - auto &constructor = adapterClass.newConstructor( - attrSizedOperands - ? "ArrayRef<Value> values, DictionaryAttr attrs" - : "ArrayRef<Value> values, DictionaryAttr attrs = nullptr"); - constructor.body() << " odsOperands = values;\n"; - constructor.body() << " odsAttrs = attrs;\n"; + { + auto &constructor = adaptor.newConstructor( + attrSizedOperands + ?
"ValueRange values, DictionaryAttr attrs" + : "ValueRange values, DictionaryAttr attrs = nullptr"); + constructor.addMemberInitializer("odsOperands", "values"); + constructor.addMemberInitializer("odsAttrs", "attrs"); + } + + { + auto &constructor = adaptor.newConstructor( + llvm::formatv("{0} op", op.getCppClassName()).str()); + constructor.addMemberInitializer("odsOperands", + "op.getOperation()->getOperands()"); + constructor.addMemberInitializer("odsAttrs", + "op.getOperation()->getAttrDictionary()"); + } std::string sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes"); - generateNamedOperandGetters(op, adapterClass, sizeAttrInit, - /*rangeType=*/"ArrayRef<Value>", + generateNamedOperandGetters(op, adaptor, sizeAttrInit, + /*rangeType=*/"ValueRange", /*rangeBeginCall=*/"odsOperands.begin()", /*rangeSizeCall=*/"odsOperands.size()", /*getOperandCallPattern=*/"odsOperands[{0}]"); @@ -1919,7 +1930,7 @@ fctx.withBuilder("mlir::Builder(odsAttrs.getContext())"); auto emitAttr = [&](StringRef name, Attribute attr) { - auto &body = adapterClass.newMethod(attr.getStorageType(), name).body(); + auto &body = adaptor.newMethod(attr.getStorageType(), name).body(); body << " assert(odsAttrs && \"no attributes when constructing adapter\");" << "\n " << attr.getStorageType() << " attr = " << "odsAttrs.get(\"" << name << "\")."; @@ -1949,11 +1960,11 @@ } void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) { - OpOperandAdaptorEmitter(op).adapterClass.writeDeclTo(os); + OpOperandAdaptorEmitter(op).adaptor.writeDeclTo(os); } void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) { - OpOperandAdaptorEmitter(op).adapterClass.writeDefTo(os); + OpOperandAdaptorEmitter(op).adaptor.writeDefTo(os); } // Emits the opcode enum and op classes.