diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h --- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h +++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h @@ -16,6 +16,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/ParallelCombiningOpInterface.h" #include "mlir/Interfaces/ShapedOpInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -12,6 +12,7 @@ include "mlir/Dialect/Tensor/IR/TensorBase.td" include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/DestinationStyleOpInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/ParallelCombiningOpInterface.td" include "mlir/Interfaces/ShapedOpInterfaces.td" @@ -679,6 +680,7 @@ def Tensor_InsertOp : Tensor_Op<"insert", [ DeclareOpInterfaceMethods, + DestinationStyleOpInterface, Pure, TypesMatchWith<"result type matches type of dest", "dest", "result", @@ -720,6 +722,12 @@ build($_builder, $_state, resType, scalar, dest, indices); }]>]; + let extraClassDeclaration = [{ + std::pair<int64_t, int64_t> getOutputsPositionRange() { + return {1, 2}; // `dest` operand + } + }]; + let hasFolder = 1; let hasVerifier = 1; } @@ -732,6 +740,7 @@ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, AttrSizedOperandSegments, + DestinationStyleOpInterface, Pure, OffsetSizeAndStrideOpInterface, TypesMatchWith<"expected result type to match dest type", @@ -858,6 +867,10 @@ /// Return the number of leading operands before the `offsets`, `sizes` and /// and `strides` operands.
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; } + + std::pair<int64_t, int64_t> getOutputsPositionRange() { + return {1, 2}; // `dest` operand + } }]; let hasCanonicalizer = 1; diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -5132,6 +5132,7 @@ deps = [ ":CastInterfacesTdFiles", ":ControlFlowInterfacesTdFiles", + ":DestinationStyleOpInterfaceTdFiles", ":InferTypeOpInterfaceTdFiles", ":OpBaseTdFiles", ":ParallelCombiningOpInterfaceTdFiles",