diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -429,14 +429,14 @@ ```c++ // All result-types/operands/attributes have one aggregate parameter. -static void build(Builder *tblgen_builder, OperationState &tblgen_state, +static void build(Builder *odsBuilder, OperationState &odsState, ArrayRef resultTypes, ValueRange operands, ArrayRef attributes); // Each result-type/operand/attribute has a separate parameter. The parameters // for attributes are of mlir::Attribute types. -static void build(Builder *tblgen_builder, OperationState &tblgen_state, +static void build(Builder *odsBuilder, OperationState &odsState, Type i32_result, Type f32_result, ..., Value i32_operand, Value f32_operand, ..., IntegerAttr i32_attr, FloatAttr f32_attr, ...); @@ -445,20 +445,20 @@ // for attributes are raw values unwrapped with mlir::Attribute instances. // (Note that this builder will not always be generated. See the following // explanation for more details.) -static void build(Builder *tblgen_builder, OperationState &tblgen_state, +static void build(Builder *odsBuilder, OperationState &odsState, Type i32_result, Type f32_result, ..., Value i32_operand, Value f32_operand, ..., APInt i32_attr, StringRef f32_attr, ...); // Each operand/attribute has a separate parameter but result type is aggregate. -static void build(Builder *tblgen_builder, OperationState &tblgen_state, +static void build(Builder *odsBuilder, OperationState &odsState, ArrayRef resultTypes, Value i32_operand, Value f32_operand, ..., IntegerAttr i32_attr, FloatAttr f32_attr, ...); // All operands/attributes have aggregate parameters. // Generated if InferTypeOpInterface interface is specified. 
-static void build(Builder *tblgen_builder, OperationState &tblgen_state, +static void build(Builder *odsBuilder, OperationState &odsState, ValueRange operands, ArrayRef attributes); @@ -1099,7 +1099,7 @@ * The op's traits (e.g., commutative) are modelled along with the op in the registry. * The op's operand/return type constraints are modelled along with the op in - the registry (see [Shape inference](#shape-inference) discussion below), + the registry (see [Shape inference](ShapeInference.md) discussion below), this allows (e.g.) optimized concise syntax in textual dumps. * Behavior of the op is documented along with the op with a summary and a description. The description is written in markdown and extracted for @@ -1156,49 +1156,6 @@ Printing is effectively the inverse of the parsing function generated with the mnemonic string serving as a template. -### Shape inference - -Type constraints are along (at least) three axis: 1) elemental type, 2) rank -(including static or dynamic), 3) dimensions. While some ops have no compile -time fixed shape (e.g., output shape is dictated by data) we could still have -some knowledge of constraints/bounds in the system for that op (e.g., the output -of a `tf.where` is at most the size of the input data). And so there are -additional valuable constraints that could be captured even without full -knowledge. - -Initially the shape inference will be declaratively specified using: - -* Constraint on the operands of an operation directly. For example - constraining the input type to be tensor/vector elements or that the - elemental type be of a specific type (e.g., output of sign is of elemental - type `i1`) or class (e.g., float like). -* Constraints across operands and results of an operation. For example, - enabling specifying equality constraints on type/constituents of a type - (shape and elemental type) between operands and results (e.g., the output - type of an add is the same as those of the input operands). 
- -In general there is an input/output transfer function which maps the inputs to -the outputs (e.g., given input X and Y [or slices thereof] with these sizes, the -output is Z [or this slice thereof]). Such a function could be used to determine -the output type (shape) for given input type (shape). - -But shape functions are determined by attributes and could be arbitrarily -complicated with a wide-range of specification possibilities. Equality -relationships are common (e.g., the elemental type of the output matches the -primitive type of the inputs, both inputs have exactly the same type [primitive -type and shape]) and so these should be easy to specify. Algebraic relationships -would also be common (e.g., a concat of `[n,m]` and `[n,m]` matrix along axis 0 -is `[n+n, m]` matrix), while some ops only have defined shapes under certain -cases (e.g., matrix multiplication of `[a,b]` and `[c,d]` is only defined if -`b == c`). As ops are also verified, the shape inference need only specify rules -for the allowed cases (e.g., shape inference for matmul can ignore the case -where `b != c`), which would simplify type constraint specification. - -Instead of specifying an additional mechanism to specify a shape transfer -function, the reference implementation of the operation will be used to derive -the shape function. The reference implementation is general and can support the -arbitrary computations needed to specify output shapes. - [TableGen]: https://llvm.org/docs/TableGen/index.html [TableGenIntro]: https://llvm.org/docs/TableGen/LangIntro.html [TableGenRef]: https://llvm.org/docs/TableGen/LangRef.html diff --git a/mlir/docs/ShapeInference.md b/mlir/docs/ShapeInference.md new file mode 100644 --- /dev/null +++ b/mlir/docs/ShapeInference.md @@ -0,0 +1,72 @@ +# Shape inference + +Shape inference as discussed here is considered a specific instance of type +inference for [ShapedType][ShapedType]. 
Type constraints are along (at least) +three axes: 1) elemental type, 2) rank (including static or dynamic), 3) +dimensions. While some operations have no compile time fixed shape (e.g., output +shape is dictated by data) we could still have some knowledge of +constraints/bounds in the system for that operation (e.g., the output of a +`tf.where` is at most the size of the input data). That is, there are additional +valuable constraints that could be captured even without full knowledge of the +shape. + +Type inference is currently modelled executionally for op creation using the +[`InferTypeOpInterface`][InferTypeOpInterface], while +`InferShapedTypeOpInterface` is used to implement the shape and element type +inference. The return type can often be deduced from the deduced return shape +and elemental type (queryable from `InferShapedTypeOpInterface`) and so type +inference for tensor types can be implemented with `InferShapedTypeOpInterface`. + +## Shape functions + +The C++ interfaces are the base mechanism whereby shape inference is queried and +executed, but not the intended way to specify shape constraints in general. + +Initially the shape inference will be declaratively specified using: + +* Constraints on the operands of an operation directly. For example + constraining the input type to be tensor/vector elements or that the + elemental type be of a specific type (e.g., output of computing the size + of a value is of elemental type `i1`) or class (e.g., float like). +* Constraints across operands and results of an operation. + + - For example, specifying equality constraints on type/constituents of a + type (shape and elemental type) between operands and results (e.g., the + output type of an add is the same as those of the input operands). + +NOTE: The C++ shape functions are an intermediate step until the shape dialect +is more full-fledged, at which point the C++ functions should become the +exceptional case. 
+ +## Testing + +Shape inference is currently tested alongside type inference by +`TestReturnTypeDriver` in the test dialect. The driver performs two checks: + +1. Verification that the return types specified match the inferred types. This + explicit check will be removed and made part of Op verification instead. +2. Test the creation of Ops without specifying the return type explicitly in + function `testCreateFunctions` by creating new binary Ops (Op classes + specified in `TestReturnTypeDriver`) using 1) all operands to + `testCreateFunctions` as both operands, and 2) using combinations of input + operands of the function. + +## WIP/Future considerations + +Shape functions are determined by attributes and could be arbitrarily +complicated with a wide range of specification possibilities. Equality +relationships are common (e.g., the elemental type of the output matches the +primitive type of the inputs, both inputs have exactly the same type [primitive +type and shape]) and so these should be easy to specify. Algebraic relationships +would also be common (e.g., a concat of `[n,m]` and `[n,m]` matrix along axis 0 +is `[n+n, m]` matrix), while some ops only have defined shapes under certain +cases (e.g., matrix multiplication of `[a,b]` and `[c,d]` is only defined if `b +== c`). + +Instead of specifying an additional mechanism to specify a shape transfer +function, the reference implementation of the operation will be used to derive +the shape function. The reference implementation is general and can support the +arbitrary computations needed to specify output shapes. 
+ +[InferTypeOpInterface]: https://github.com/llvm/llvm-project/tree/master/mlir/include/mlir/Analysis/InferTypeOpInterface.td +[ShapedType]: https://github.com/llvm/llvm-project/tree/master/mlir/include/mlir/IR/StandardTypes.h diff --git a/mlir/include/mlir/Analysis/InferTypeOpInterface.h b/mlir/include/mlir/Analysis/InferTypeOpInterface.h --- a/mlir/include/mlir/Analysis/InferTypeOpInterface.h +++ b/mlir/include/mlir/Analysis/InferTypeOpInterface.h @@ -17,28 +17,100 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Location.h" #include "mlir/IR/OpDefinition.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/SmallVector.h" namespace mlir { +/// ShapedTypeComponents that represents the components of a ShapedType. +/// The components consist of +/// - A ranked or unranked shape with the dimension specification match those +/// of ShapeType's getShape() (e.g., dynamic dimension represented using +/// ShapedType::kDynamicSize) +/// - A element type, may be unset (nullptr) +/// - A attribute, may be unset (nullptr) +/// Used by ShapedType type inferences. +class ShapedTypeComponents { + /// Internal storage type for shape. + using ShapeStorageT = SmallVector; + +public: + /// Default construction is an unranked shape. + ShapedTypeComponents() : ranked(false), elementType(nullptr), attr(nullptr){}; + + template ::value>> + ShapedTypeComponents(Arg &&arg, Type elementType = nullptr, + Attribute attr = nullptr) + : dims(std::forward(arg)), ranked(true), elementType(elementType), + attr(attr) {} + ShapedTypeComponents(ArrayRef vec, Type elementType = nullptr, + Attribute attr = nullptr) + : dims(vec.begin(), vec.end()), ranked(true), elementType(elementType), + attr(attr) {} + + /// Return the dimensions of the shape. + /// Requires: shape is ranked. + ArrayRef getDims() const { + assert(ranked && "requires ranked shape"); + return dims; + } + + /// Return whether the shape has a rank. 
+ bool hasRank() const { return ranked; }; + + /// Return the element type component. + Type getElementType() const { return elementType; }; + + /// Return the raw attribute component. + Attribute getAttribute() const { return attr; }; + +private: + ShapeStorageT dims; + bool ranked; + Type elementType; + Attribute attr; +}; + #include "mlir/Analysis/InferTypeOpInterface.h.inc" +namespace detail { +// Helper function to infer return tensor returns types given element and shape +// inference function. +// +// TODO: Consider generating typedefs for trait member functions if this usage +// becomes more common. +LogicalResult inferReturnTensorTypes( + function_ref location, ValueRange operands, + ArrayRef attributes, RegionRange regions, + SmallVectorImpl &retComponents)> + componentTypeFn, + MLIRContext *context, Optional location, ValueRange operands, + ArrayRef attributes, RegionRange regions, + SmallVectorImpl &inferedReturnTypes); +} // namespace detail + namespace OpTrait { + +/// Tensor type inference trait that constructs a tensor from the infered +/// shape and elemental types. +/// Requires: Op implements functions of InferShapedTypeOpInterface. template -class TypeOpInterfaceDefault - : public TraitBase { +class InferTensorType : public TraitBase { public: - /// Returns whether two arrays are equal as strongest check for compatibility - /// by default. 
- static bool isCompatibleReturnTypes(ArrayRef lhs, ArrayRef rhs) { - return lhs == rhs; - }; + static LogicalResult + inferReturnTypes(MLIRContext *context, Optional location, + ValueRange operands, ArrayRef attributes, + RegionRange regions, + SmallVectorImpl &inferedReturnTypes) { + return ::mlir::detail::inferReturnTensorTypes( + ConcreteType::inferReturnTypeComponents, context, location, operands, + attributes, regions, inferedReturnTypes); + } }; -} // namespace OpTrait +} // namespace OpTrait } // namespace mlir #endif // MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_ diff --git a/mlir/include/mlir/Analysis/InferTypeOpInterface.td b/mlir/include/mlir/Analysis/InferTypeOpInterface.td --- a/mlir/include/mlir/Analysis/InferTypeOpInterface.td +++ b/mlir/include/mlir/Analysis/InferTypeOpInterface.td @@ -22,9 +22,8 @@ // mismatch). def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> { let description = [{ - Interface to access a registered method to infer the return types for an - operation that could be used during op construction, verification or - type inference. + Interface to infer the return types for an operation that could be used + during op construction, verification or type inference. }]; let methods = [ @@ -38,7 +37,8 @@ }], /*retTy=*/"LogicalResult", /*methodName=*/"inferReturnTypes", - /*args=*/(ins "Optional":$location, + /*args=*/(ins "MLIRContext*":$context, + "Optional":$location, "ValueRange":$operands, "ArrayRef":$attributes, "RegionRange":$regions, @@ -62,4 +62,38 @@ ]; } +def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> { + let description = [{ + Interface to infer the components of a ShapedType returned by an operation + that could be used during op construction, verification or shape inference. + + The components consists of element type, shape and raw attribute. + }]; + + let methods = [ + StaticInterfaceMethod< + /*desc=*/[{Infer the components of return type of shape containter. 
+ + The method takes an optional location which, if set, will be used to + report errors on. The operands and attributes correspond to those with + which an Operation would be created (e.g., as used in Operation::create) + and the regions of the op. + + Unknown (e.g., unranked) shape and nullptrs for element type and attribute + may be returned by this function while returning success. E.g., partial + population of components is not error condition. + }], + /*retTy=*/"LogicalResult", + /*methodName=*/"inferReturnTypeComponents", + /*args=*/(ins "MLIRContext*":$context, + "Optional":$location, + "ValueRange":$operands, + "ArrayRef":$attributes, + "RegionRange":$regions, + "SmallVectorImpl&": + $inferedReturnShapes) + >, + ]; +} + #endif // MLIR_INFERTYPEOPINTERFACE diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1539,7 +1539,7 @@ // following signatures: // // ```c++ - // static void build(Builder *, OperationState &tblgen_state, + // static void build(Builder *, OperationState &odsState, // Type , Type , ..., // Value , Value , ..., // Attribute , Attribute , ...); @@ -1547,7 +1547,7 @@ // * where the attributes follow the same declaration order as in the op. 
// // ```c++ - // static void build(Builder *, OperationState &tblgen_state, + // static void build(Builder *, OperationState &odsState, // ArrayRef resultTypes, // ArrayRef operands, // ArrayRef attributes); diff --git a/mlir/lib/Analysis/InferTypeOpInterface.cpp b/mlir/lib/Analysis/InferTypeOpInterface.cpp --- a/mlir/lib/Analysis/InferTypeOpInterface.cpp +++ b/mlir/lib/Analysis/InferTypeOpInterface.cpp @@ -12,11 +12,36 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/InferTypeOpInterface.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/Types.h" -#include "llvm/ADT/SmallVector.h" + +#include "mlir/IR/StandardTypes.h" + +using namespace mlir; namespace mlir { #include "mlir/Analysis/InferTypeOpInterface.cpp.inc" } // namespace mlir + +LogicalResult mlir::detail::inferReturnTensorTypes( + function_ref location, ValueRange operands, + ArrayRef attributes, RegionRange regions, + SmallVectorImpl &retComponents)> + componentTypeFn, + MLIRContext *context, Optional location, ValueRange operands, + ArrayRef attributes, RegionRange regions, + SmallVectorImpl &inferedReturnTypes) { + SmallVector retComponents; + if (failed(componentTypeFn(context, location, operands, attributes, regions, + retComponents))) + return failure(); + for (auto shapeAndType : retComponents) { + assert(shapeAndType.getAttribute() == nullptr && "attribute not supported"); + if (shapeAndType.hasRank()) + inferedReturnTypes.push_back(RankedTensorType::get( + shapeAndType.getDims(), shapeAndType.getElementType())); + else + inferedReturnTypes.push_back( + UnrankedTensorType::get(shapeAndType.getElementType())); + } + return success(); +} diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp --- a/mlir/test/lib/TestDialect/TestDialect.cpp +++ b/mlir/test/lib/TestDialect/TestDialect.cpp @@ -295,7 +295,7 @@ } LogicalResult 
mlir::OpWithInferTypeInterfaceOp::inferReturnTypes( - llvm::Optional location, ValueRange operands, + MLIRContext *, Optional location, ValueRange operands, ArrayRef attributes, RegionRange regions, SmallVectorImpl &inferedReturnTypes) { if (operands[0].getType() != operands[1].getType()) { @@ -307,6 +307,30 @@ return success(); } +LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( + MLIRContext *context, Optional location, ValueRange operands, + ArrayRef attributes, RegionRange regions, + SmallVectorImpl &inferedComponents) { + // Create return type consisting of the first element of each shape of the + // input operands or unknown for unranked operand. + std::vector shape; + shape.reserve(operands.size()); + for (auto operandType : operands.getTypes()) { + if (auto sval = operandType.dyn_cast()) { + if (sval.hasRank()) + shape.push_back(sval.getShape().front()); + else + shape.push_back(ShapedType::kDynamicSize); + } else { + return emitOptionalError(location, "only shaped type operands allowed"); + } + } + inferedComponents.reserve(1); + auto type = IntegerType::get(17, context); + inferedComponents.emplace_back(shape, type); + return success(); +} + // Static initialization for Test dialect registration. static mlir::DialectRegistration testDialect; diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -402,6 +402,21 @@ let results = (outs AnyTensor); } +def InferTensorType : NativeOpTrait<"InferTensorType">; +def OpWithShapedTypeInferTypeInterfaceOp : TEST_Op<"op_with_shaped_type_infer_type_if", + [ + // Op implements infer type op interface. + InferTypeOpInterface, + // The op will have methods implementing the ShapedType type infer interface. 
+ DeclareOpInterfaceMethods, + // The op produces tensors and will use the ShapedType type infer interface + // along with knowledge that it is producing Tensors to infer shape. + InferTensorType + ]> { + let arguments = (ins AnyTensor, AnyTensor); + let results = (outs AnyTensor); +} + def IsNotScalar : Constraint>; def UpdateAttr : Pat<(I32ElementsAttrOp $attr), diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -58,50 +58,71 @@ //===----------------------------------------------------------------------===// namespace { -struct ReturnTypeOpMatch : public RewritePattern { - ReturnTypeOpMatch(MLIRContext *ctx) - : RewritePattern(OpWithInferTypeInterfaceOp::getOperationName(), 1, ctx) { - } - - PatternMatchResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const final { - if (auto retTypeFn = dyn_cast(op)) { - SmallVector values(op->getOperands()); +// Generate ops for each instance where the type can be succesfully infered. +template +static void invokeCreateWithInferedReturnType(Operation *op) { + auto *context = op->getContext(); + auto fop = op->getParentOfType(); + auto location = UnknownLoc::get(context); + OpBuilder b(op); + b.setInsertionPointAfter(op); + + // Use permutations of 2 args as operands. 
+ assert(fop.getNumArguments() >= 2); + for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { + for (int j = 0; j < e; ++j) { + std::array values = {fop.getArgument(i), fop.getArgument(j)}; SmallVector inferedReturnTypes; - if (failed(retTypeFn.inferReturnTypes(op->getLoc(), values, - op->getAttrs(), op->getRegions(), - inferedReturnTypes))) - return matchFailure(); - SmallVector resultTypes(op->getResultTypes()); - if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes)) - return op->emitOpError( - "inferred type incompatible with return type of operation"), - matchFailure(); - - // TODO(jpienaar): Split this out to make the test more focused. - // Create new op with unknown location to verify building with - // InferTypeOpInterface is triggered. - auto fop = op->getParentOfType(); - if (values[0] == fop.getArgument(0)) { - // Use the 2nd function argument if the first function argument is used - // when constructing the new op so that a new return type is inferred. - values[0] = fop.getArgument(1); - values[1] = fop.getArgument(1); + if (succeeded(OpTy::inferReturnTypes(context, llvm::None, values, + op->getAttrs(), op->getRegions(), + inferedReturnTypes))) { + OperationState state(location, OpTy::getOperationName()); // TODO(jpienaar): Expand to regions. - rewriter.create( - UnknownLoc::get(op->getContext()), values, op->getAttrs()); + OpTy::build(&b, state, values, op->getAttrs()); + (void)b.createOperation(state); } } - return matchFailure(); } -}; +} struct TestReturnTypeDriver : public FunctionPass { void runOnFunction() override { - mlir::OwningRewritePatternList patterns; - populateWithGenerated(&getContext(), &patterns); - patterns.insert(&getContext()); - applyPatternsGreedily(getFunction(), patterns); + if (getFunction().getName() == "testCreateFunctions") { + std::vector ops; + // Collect ops to avoid triggering on inserted ops. 
+ for (auto &op : getFunction().getBody().front()) + ops.push_back(&op); + // Generate test patterns for each, but skip terminator. + for (auto *op : llvm::makeArrayRef(ops).drop_back()) { + // Test create method of each of the Op classes below. The resultant + // output would be in reverse order underneath `op` from which + // the attributes and regions are used. + invokeCreateWithInferedReturnType(op); + invokeCreateWithInferedReturnType( + op); + }; + return; + } + + // Verification check. + // TODO: Move to ops that implement type infer interface. + getFunction().walk([this](Operation *op) -> void { + auto retTypeFn = dyn_cast(op); + if (!retTypeFn) + return; + auto *context = &getContext(); + SmallVector inferedReturnTypes; + if (failed(retTypeFn.inferReturnTypes( + context, op->getLoc(), op->getOperands(), op->getAttrs(), + op->getRegions(), inferedReturnTypes))) + return; + SmallVector resultTypes(op->getResultTypes()); + if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes)) { + op->emitOpError( + "inferred type incompatible with return type of operation"); + return; + } + }); } }; } // end anonymous namespace diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -56,18 +56,18 @@ // --- // DEF: void AOp::build( -// DEF: tblgen_state.addAttribute("aAttr", aAttr); -// DEF: tblgen_state.addAttribute("bAttr", bAttr); +// DEF: odsState.addAttribute("aAttr", aAttr); +// DEF: odsState.addAttribute("bAttr", bAttr); // DEF: if (cAttr) { -// DEF-NEXT: tblgen_state.addAttribute("cAttr", cAttr); +// DEF-NEXT: odsState.addAttribute("cAttr", cAttr); // DEF: void AOp::build( // DEF: some-return-type aAttr, some-return-type bAttr, /*optional*/some-attr-kind cAttr -// DEF: tblgen_state.addAttribute("aAttr", some-const-builder-call((*tblgen_builder), aAttr)); +// DEF: odsState.addAttribute("aAttr", 
some-const-builder-call((*odsBuilder), aAttr)); // DEF: void AOp::build( // DEF: ArrayRef attributes -// DEF: tblgen_state.addAttributes(attributes); +// DEF: odsState.addAttributes(attributes); // Test verify method // --- @@ -218,7 +218,7 @@ // DEF-LABEL: MixOperandsAndAttrs definitions // DEF-DAG: Value MixOperandsAndAttrs::operand() // DEF-DAG: Value MixOperandsAndAttrs::otherArg() -// DEF-DAG: void MixOperandsAndAttrs::build(Builder *tblgen_builder, OperationState &tblgen_state, FloatAttr attr, Value operand, FloatAttr otherAttr, Value otherArg) +// DEF-DAG: void MixOperandsAndAttrs::build(Builder *odsBuilder, OperationState &odsState, FloatAttr attr, Value operand, FloatAttr otherAttr, Value otherArg) // DEF-DAG: APFloat MixOperandsAndAttrs::attr() // DEF-DAG: APFloat MixOperandsAndAttrs::otherAttr() @@ -233,4 +233,4 @@ // DEF: bool UnitAttrOp::attr() { // DEF: return {{.*}} != nullptr -// DEF: build(Builder *tblgen_builder, OperationState &tblgen_state, /*optional*/UnitAttr attr) +// DEF: build(Builder *odsBuilder, OperationState &odsState, /*optional*/UnitAttr attr) diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -70,9 +70,9 @@ // CHECK: FloatAttr attr2Attr() // CHECK: Optional< APFloat > attr2(); // CHECK: static void build(Value val); -// CHECK: static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type r, ArrayRef s, Value a, ValueRange b, IntegerAttr attr1, /*optional*/FloatAttr attr2) -// CHECK: static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type r, ArrayRef s, Value a, ValueRange b, APInt attr1, /*optional*/FloatAttr attr2) -// CHECK: static void build(Builder *, OperationState &tblgen_state, ArrayRef resultTypes, ValueRange operands, ArrayRef attributes) +// CHECK: static void build(Builder *odsBuilder, OperationState &odsState, Type r, ArrayRef s, Value a, ValueRange b, IntegerAttr attr1, 
/*optional*/FloatAttr attr2) +// CHECK: static void build(Builder *odsBuilder, OperationState &odsState, Type r, ArrayRef s, Value a, ValueRange b, APInt attr1, /*optional*/FloatAttr attr2) +// CHECK: static void build(Builder *, OperationState &odsState, ArrayRef resultTypes, ValueRange operands, ArrayRef attributes) // CHECK: static ParseResult parse(OpAsmParser &parser, OperationState &result); // CHECK: void print(OpAsmPrinter &p); // CHECK: LogicalResult verify(); diff --git a/mlir/test/mlir-tblgen/op-operand.td b/mlir/test/mlir-tblgen/op-operand.td --- a/mlir/test/mlir-tblgen/op-operand.td +++ b/mlir/test/mlir-tblgen/op-operand.td @@ -19,12 +19,12 @@ // CHECK: void OpA::build // CHECK: Value input -// CHECK: tblgen_state.addOperands(input); +// CHECK: odsState.addOperands(input); // CHECK: void OpA::build // CHECK: ValueRange operands // CHECK: assert(operands.size() == 1u && "mismatched number of parameters"); -// CHECK: tblgen_state.addOperands(operands); +// CHECK: odsState.addOperands(operands); def OpB : NS_Op<"one_variadic_operand_op", []> { let arguments = (ins Variadic:$input); @@ -33,7 +33,7 @@ // CHECK-LABEL: OpB::build // CHECK: ValueRange input // CHECK-NOT: assert -// CHECK: tblgen_state.addOperands(input); +// CHECK: odsState.addOperands(input); def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]> { let arguments = (ins Variadic:$input1, AnyTensor:$input2, Variadic:$input3); @@ -55,6 +55,6 @@ // CHECK-NEXT: return *getODSOperands(1).begin(); // CHECK-LABEL: OpD::build -// CHECK-NEXT: tblgen_state.addOperands(input1); -// CHECK-NEXT: tblgen_state.addOperands(input2); -// CHECK-NEXT: tblgen_state.addOperands(input3); +// CHECK-NEXT: odsState.addOperands(input1); +// CHECK-NEXT: odsState.addOperands(input2); +// CHECK-NEXT: odsState.addOperands(input3); diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ 
-15,7 +15,7 @@ // CHECK-LABEL: void OpA::build // CHECK: ArrayRef resultTypes, ValueRange operands // CHECK: assert(resultTypes.size() == 1u && "mismatched number of return types"); -// CHECK-NEXT: tblgen_state.addTypes(resultTypes); +// CHECK-NEXT: odsState.addTypes(resultTypes); def OpB : NS_Op<"same_input_output_type_op", [SameOperandsAndResultType]> { let arguments = (ins I32:$x); @@ -23,20 +23,20 @@ } // CHECK-LABEL: OpB definitions -// CHECK: void OpB::build(Builder *tblgen_builder, OperationState &tblgen_state, Type y, Value x) -// CHECK: tblgen_state.addTypes(y); -// CHECK: void OpB::build(Builder *tblgen_builder, OperationState &tblgen_state, Value x) -// CHECK: tblgen_state.addTypes({x.getType()}); +// CHECK: void OpB::build(Builder *odsBuilder, OperationState &odsState, Type y, Value x) +// CHECK: odsState.addTypes(y); +// CHECK: void OpB::build(Builder *odsBuilder, OperationState &odsState, Value x) +// CHECK: odsState.addTypes({x.getType()}); def OpC : NS_Op<"three_normal_result_op", []> { let results = (outs I32:$x, /*unnamed*/I32, I32:$z); } // CHECK-LABEL: OpC definitions -// CHECK: void OpC::build(Builder *tblgen_builder, OperationState &tblgen_state, Type x, Type resultType1, Type z) -// CHECK-NEXT: tblgen_state.addTypes(x) -// CHECK-NEXT: tblgen_state.addTypes(resultType1) -// CHECK-NEXT: tblgen_state.addTypes(z) +// CHECK: void OpC::build(Builder *odsBuilder, OperationState &odsState, Type x, Type resultType1, Type z) +// CHECK-NEXT: odsState.addTypes(x) +// CHECK-NEXT: odsState.addTypes(resultType1) +// CHECK-NEXT: odsState.addTypes(z) def IntegerTypeAttr : TypeAttrBase<"IntegerType", "Integer type attribute">; def OpD : NS_Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> { @@ -45,8 +45,8 @@ } // CHECK-LABEL: OpD definitions -// CHECK: void OpD::build(Builder *, OperationState &tblgen_state, ValueRange operands, ArrayRef attributes) -// CHECK: tblgen_state.addTypes({attr.second.cast().getValue()}); +// CHECK: void OpD::build(Builder 
*, OperationState &odsState, ValueRange operands, ArrayRef attributes) +// CHECK: odsState.addTypes({attr.second.cast().getValue()}); def OpE : NS_Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> { let arguments = (ins I32:$x, F32Attr:$attr); @@ -54,8 +54,8 @@ } // CHECK-LABEL: OpE definitions -// CHECK: void OpE::build(Builder *, OperationState &tblgen_state, ValueRange operands, ArrayRef attributes) -// CHECK: tblgen_state.addTypes({attr.second.getType()}); +// CHECK: void OpE::build(Builder *, OperationState &odsState, ValueRange operands, ArrayRef attributes) +// CHECK: odsState.addTypes({attr.second.getType()}); def OpF : NS_Op<"one_variadic_result_op", []> { let results = (outs Variadic:$x); @@ -64,7 +64,7 @@ // CHECK-LABEL: void OpF::build // CHECK-SAME: ArrayRef x // CHECK-NOT: assert -// CHECK: tblgen_state.addTypes(x); +// CHECK: odsState.addTypes(x); def OpG : NS_Op<"one_normal_and_one_variadic_result_op", []> { @@ -73,14 +73,14 @@ // CHECK-LABEL: OpG definitions -// CHECK: void OpG::build(Builder *tblgen_builder, OperationState &tblgen_state, Type x, ArrayRef y) -// CHECK-NEXT: tblgen_state.addTypes(x); -// CHECK-NEXT: tblgen_state.addTypes(y); +// CHECK: void OpG::build(Builder *odsBuilder, OperationState &odsState, Type x, ArrayRef y) +// CHECK-NEXT: odsState.addTypes(x); +// CHECK-NEXT: odsState.addTypes(y); // CHECK: void OpG::build // CHECK: ArrayRef resultTypes // CHECK: assert(resultTypes.size() >= 1u && "mismatched number of return types"); -// CHECK-NEXT: tblgen_state.addTypes(resultTypes); +// CHECK-NEXT: odsState.addTypes(resultTypes); def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]> { let results = (outs Variadic:$output1, AnyTensor:$output2, Variadic:$output3); @@ -93,9 +93,9 @@ // CHECK-NEXT: return *getODSResults(1).begin(); // CHECK-LABEL: OpI::build -// CHECK-NEXT: tblgen_state.addTypes(output1); -// CHECK-NEXT: tblgen_state.addTypes(output2); -// CHECK-NEXT: tblgen_state.addTypes(output3); 
+// CHECK-NEXT: odsState.addTypes(output1); +// CHECK-NEXT: odsState.addTypes(output2); +// CHECK-NEXT: odsState.addTypes(output3); // Test that if the only operand is variadic, we access the first value in the // pack to set result type @@ -105,5 +105,5 @@ let results = (outs AnyTensor:$result); } -// CHECK-LABEL: OpK::build(Builder *tblgen_builder, OperationState &tblgen_state, ValueRange input) -// CHECK: tblgen_state.addTypes({input.front().getType()}); +// CHECK-LABEL: OpK::build(Builder *odsBuilder, OperationState &odsState, ValueRange input) +// CHECK: odsState.addTypes({input.front().getType()}); diff --git a/mlir/test/mlir-tblgen/return-types.mlir b/mlir/test/mlir-tblgen/return-types.mlir --- a/mlir/test/mlir-tblgen/return-types.mlir +++ b/mlir/test/mlir-tblgen/return-types.mlir @@ -1,12 +1,23 @@ // RUN: mlir-opt %s -test-return-type -split-input-file -verify-diagnostics | FileCheck %s --dump-input-on-failure -// CHECK-LABEL: testReturnTypeOpInterface -func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>, %arg1 : tensor<20xi32>) { - %good = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> - // CHECK: test.op_with_infer_type_if - // CHECK-SAME: tensor<20xi32> - // CHECK: test.op_with_infer_type_if - // CHECK-SAME: tensor<10xf32> +// CHECK-LABEL: testCreateFunctions +// This function tests invoking the create method with different inference +// methods. The attributes of the ops inside are used to test creation. 
+func @testCreateFunctions(%arg0 : tensor<10xf32>, %arg1 : tensor<20xi32>) { +// CHECK: "test.no_attributes" + %good = "test.no_attributes"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> +// CHECK: "test.op_with_shaped_type_infer_type_if" +// CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xi17> +// CHECK: "test.op_with_shaped_type_infer_type_if" +// CHECK-SAME: (tensor<10xf32>, tensor<20xi32>) -> tensor<10x20xi17> +// CHECK: "test.op_with_shaped_type_infer_type_if" +// CHECK-SAME: (tensor<20xi32>, tensor<10xf32>) -> tensor<20x10xi17> +// CHECK: "test.op_with_shaped_type_infer_type_if" +// CHECK-SAME: (tensor<20xi32>, tensor<20xi32>) -> tensor<20x20xi17> +// CHECK: "test.op_with_infer_type_if" +// CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> +// CHECK: "test.op_with_infer_type_if" +// CHECK-SAME: (tensor<20xi32>, tensor<20xi32>) -> tensor<20xi32> return } diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -58,8 +58,8 @@ //===----------------------------------------------------------------------===// static const char *const tblgenNamePrefix = "tblgen_"; -static const char *const generatedArgName = "tblgen_arg"; -static const char *const builderOpState = "tblgen_state"; +static const char *const generatedArgName = "odsArg"; +static const char *const builderOpState = "odsState"; // The logic to calculate the actual value range for a declared operand/result // of an op with variadic operands/results. Note that this logic is not for @@ -627,8 +627,9 @@ // TODO(jpienaar): Expand to handle regions. 
body << formatv(R"( SmallVector inferedReturnTypes; - if (succeeded({0}::inferReturnTypes({1}.location, {1}.operands, - {1}.attributes, /*regions=*/{{}, inferedReturnTypes))) + if (succeeded({0}::inferReturnTypes(odsBuilder->getContext(), + {1}.location, {1}.operands, {1}.attributes, + /*regions=*/{{}, inferedReturnTypes))) {1}.addTypes(inferedReturnTypes); else llvm::report_fatal_error("Failed to infer result type(s).");)", @@ -702,7 +703,7 @@ void OpEmitter::genInferedTypeCollectiveParamBuilder() { // TODO(jpienaar): Expand to support regions. const char *params = - "Builder *builder, OperationState &{0}, " + "Builder *odsBuilder, OperationState &{0}, " "ValueRange operands, ArrayRef attributes"; auto &m = opClass.newMethod("void", "build", formatv(params, builderOpState).str(), @@ -710,9 +711,10 @@ auto &body = m.body(); body << formatv(R"( SmallVector inferedReturnTypes; - if (succeeded({0}::inferReturnTypes({1}.location, operands, attributes, + if (succeeded({0}::inferReturnTypes(odsBuilder->getContext(), + {1}.location, operands, attributes, /*regions=*/{{}, inferedReturnTypes))) - build(builder, tblgen_state, inferedReturnTypes, operands, attributes); + build(odsBuilder, odsState, inferedReturnTypes, operands, attributes); else llvm::report_fatal_error("Failed to infer result type(s).");)", opClass.getClassName(), builderOpState); @@ -878,7 +880,7 @@ auto numResults = op.getNumResults(); resultTypeNames.reserve(numResults); - paramList = "Builder *tblgen_builder, OperationState &"; + paramList = "Builder *odsBuilder, OperationState &"; paramList.append(builderOpState); switch (typeParamKind) { @@ -1000,7 +1002,7 @@ // If this is a raw value, then we need to wrap it in an Attribute // instance. FmtContext fctx; - fctx.withBuilder("(*tblgen_builder)"); + fctx.withBuilder("(*odsBuilder)"); std::string value = tgfmt(attr.getConstBuilderTemplate(), &fctx, namedAttr.name); body << formatv(" {0}.addAttribute(\"{1}\", {2});\n", builderOpState,