diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td @@ -107,11 +107,23 @@ Insert operations using the given OpBuilder that computes the result shape. Only one of this method or `reifyReturnTypeShapesPerResultDim` needs to be overriden by the - operation. + operation. This interface is designed to work during dialect + conversion (e.g. when converting from the tensor world to the buffer + world), where calling `getOperand` may be invalid. For example, some + ops (e.g. dynamic_reshape(input, target_shape)) depend on their + operands to calculate the result shape. When the `matchAndRewrite` + method of a conversion pattern is called, the operands of the op being + converted may already have been converted to other types, which makes + it invalid to call the op's `getOperand` method directly inside the + conversion pattern. To solve this problem, this interface follows the + design of conversion patterns: it accepts the operands as an explicit + argument so implementations use them instead of calling `getOperand` + directly. }], /*retTy=*/"::mlir::LogicalResult", /*methodName=*/"reifyReturnTypeShapes", /*args=*/(ins "::mlir::OpBuilder&":$builder, + "::mlir::ValueRange":$operands, "::mlir::SmallVectorImpl &":$reifiedReturnShapes), /*methodBody=*/[{}], /*defaultImplementation=*/[{ return ::mlir::failure(); }] diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -734,8 +734,8 @@ // check if the op implements the first interface method or the second, and // get the value to use appropriately. 
SmallVector reifiedResultShapes; - if (succeeded( - shapedTypeOp.reifyReturnTypeShapes(builder, reifiedResultShapes))) { + if (succeeded(shapedTypeOp.reifyReturnTypeShapes( + builder, result.getOwner()->getOperands(), reifiedResultShapes))) { if (reifiedResultShapes.size() <= resultNumber) return nullptr; Value resultShape = reifiedResultShapes[resultNumber]; diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -748,9 +748,10 @@ } LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( - OpBuilder &builder, llvm::SmallVectorImpl &shapes) { + OpBuilder &builder, ValueRange operands, + llvm::SmallVectorImpl &shapes) { shapes = SmallVector{ - builder.createOrFold(getLoc(), getOperand(0), 0)}; + builder.createOrFold(getLoc(), operands.front(), 0)}; return success(); } diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -137,7 +137,7 @@ // Use permutations of 2 args as operands. auto shapedOp = cast(op); SmallVector shapes; - if (failed(shapedOp.reifyReturnTypeShapes(b, shapes)) || + if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || !llvm::hasSingleElement(shapes)) return; for (auto it : llvm::enumerate(shapes)) {