diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td @@ -114,7 +114,10 @@ /*args=*/(ins "::mlir::OpBuilder&":$builder, "::mlir::SmallVectorImpl<::mlir::Value> &":$reifiedReturnShapes), /*methodBody=*/[{}], - /*defaultImplementation=*/[{ return ::mlir::failure(); }] + /*defaultImplementation=*/[{ + auto operands = this->getOperation()->getOperands(); + return this->reifyReturnTypeShapesWithOperands(builder, operands, reifiedReturnShapes); + }] >, InterfaceMethod< /*desc=*/[{Reify the shape computation for the operation. @@ -143,6 +146,33 @@ :$reifiedReturnShapes), /*methodBody=*/[{}], /*defaultImplementation=*/[{ return ::mlir::failure(); }] + >, + InterfaceMethod< + /*desc=*/[{Reify the shape computation for the operation. + + Insert operations using the given OpBuilder that compute the + result shape. This interface is similar to `reifyReturnTypeShapes` + except that it accepts a user-provided version of operands instead + of fetching them through the normal `getOperand` accessor. This interface is + supposed to be used during dialect conversion (e.g. convert from + tensor world to buffer world), where `getOperand` may be invalid. + For example, some ops (e.g. dynamic_reshape(input, target_shape)) + may depend on their operands to calculate the result shape. When + the `matchAndRewrite` method of a conversion pattern is called, + the operands of the op to convert may have been converted into other + types, which makes it invalid to call the `getOperand` method of + such op directly inside the conversion pattern. To solve this + problem, this interface follows the design of the conversion pattern, + that is, accepting passed-in operands to avoid calling `getOperand` + directly inside the interface implementation. 
+ }], + /*retTy=*/"::mlir::LogicalResult", + /*methodName=*/"reifyReturnTypeShapesWithOperands", + /*args=*/(ins "::mlir::OpBuilder&":$builder, + "::mlir::ValueRange":$operands, + "::mlir::SmallVectorImpl<::mlir::Value> &":$reifiedReturnShapes), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ return ::mlir::failure(); }] > ]; }