diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td @@ -107,17 +107,26 @@ Insert operations using the given OpBuilder that computes the result shape. Only one of this method or `reifyReturnTypeShapesPerResultDim` needs to be overriden by the - operation. + operation. This method is designed to be usable during dialect + conversion (e.g. when converting from the tensor world to the buffer + world), where `getOperand` may be invalid. For example, some ops (e.g. + dynamic_reshape(input, target_shape)) may depend on their operands + to calculate the result shape. When the `matchAndRewrite` method + of a conversion pattern is called, the operands of the op to convert + may have been converted into other types, which makes it invalid to + call the `getOperand` method of such an op directly inside the + conversion pattern. To solve this problem, this method follows + the design of the conversion pattern, that is, it accepts passed-in + operands to avoid calling `getOperand` directly inside the method + implementation. }], /*retTy=*/"::mlir::LogicalResult", /*methodName=*/"reifyReturnTypeShapes", /*args=*/(ins "::mlir::OpBuilder&":$builder, + "::mlir::ValueRange":$operands, "::mlir::SmallVectorImpl &":$reifiedReturnShapes), /*methodBody=*/[{}], - /*defaultImplementation=*/[{ - auto operands = this->getOperation()->getOperands(); - return this->reifyReturnTypeShapesWithOperands(builder, operands, reifiedReturnShapes); - }] + /*defaultImplementation=*/[{ return ::mlir::failure(); }] >, InterfaceMethod< /*desc=*/[{Reify the shape computation for the operation. @@ -146,33 +155,6 @@ :$reifiedReturnShapes), /*methodBody=*/[{}], /*defaultImplementation=*/[{ return ::mlir::failure(); }] - >, - InterfaceMethod< - /*desc=*/[{Reify the shape computation for the operation. 
- - Insert operations using the given OpBuilder that computes the - result shape. This interface is similar to `reifyReturnTypeShapes` - except that it accepts a user-provided version of operands instead - of using the normal `getOperand` way to fetch. This interface is - supposed to be used during dialect conversion (e.g. convert from - tensor world to buffer world), where `getOperand` may be invalid. - For example, some ops (e.g. dynamic_reshape(input, target_shape)) - may depend on their operands to calculate the result shape. When - the `matchAndRewrite ` method of a conversion pattern is called, - the operands of the op to convert may have been converted into other - types, which makes it invalid to call the `getOperand` method of - such op directly inside the conversion pattern. To solve this - problem, this interface follows the design of the conversion pattern, - that is, accepting passed in operands to avoid calling `getOperand` - directly inside the interface implementation. - }], - /*retTy=*/"::mlir::LogicalResult", - /*methodName=*/"reifyReturnTypeShapesWithOperands", - /*args=*/(ins "::mlir::OpBuilder&":$builder, - "::mlir::ValueRange":$operands, - "::mlir::SmallVectorImpl &":$reifiedReturnShapes), - /*methodBody=*/[{}], - /*defaultImplementation=*/[{ return ::mlir::failure(); }] > ]; } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -734,8 +734,8 @@ // check if the op implements the first interface method or the second, and // get the value to use appropriately. 
SmallVector reifiedResultShapes; - if (succeeded( - shapedTypeOp.reifyReturnTypeShapes(builder, reifiedResultShapes))) { + if (succeeded(shapedTypeOp.reifyReturnTypeShapes( + builder, result.getOwner()->getOperands(), reifiedResultShapes))) { if (reifiedResultShapes.size() <= resultNumber) return nullptr; Value resultShape = reifiedResultShapes[resultNumber]; diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -748,7 +748,9 @@ } LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( - OpBuilder &builder, llvm::SmallVectorImpl &shapes) { + OpBuilder &builder, ValueRange operands, + llvm::SmallVectorImpl &shapes) { + // Read `operands`, not getOperand(0), which may be invalid mid-conversion. shapes = SmallVector{ - builder.createOrFold(getLoc(), getOperand(0), 0)}; + builder.createOrFold(getLoc(), operands.front(), 0)}; return success(); diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -137,7 +137,7 @@ // Use permutations of 2 args as operands. auto shapedOp = cast(op); SmallVector shapes; - if (failed(shapedOp.reifyReturnTypeShapes(b, shapes)) || + if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || !llvm::hasSingleElement(shapes)) return; for (auto it : llvm::enumerate(shapes)) {