[mlir][tensor] Attach DestinationStyleOpInterface to the tensor insert ops

Summary of this patch (derived from the hunks below):
  * Tensor.h / TensorOps.td: include the DestinationStyleOpInterface
    C++ header and TableGen file.
  * Tensor_InsertOp gains DestinationStyleOpInterface plus an
    extraClassDeclaration whose getOutputsPositionRange() returns {1, 2},
    i.e. exactly one output operand at index 1 (`dest`). The op's custom
    builder forwards (scalar, dest, indices) in that order, which is why
    `dest` is operand #1.
  * The op in the hunk starting at old line 736 (AttrSizedOperandSegments,
    OffsetSizeAndStrideOpInterface, "expected result type to match dest
    type"; presumably tensor.insert_slice -- TODO confirm) also gains the
    interface and the same {1, 2} output range. Its
    getOffsetSizeAndStrideStartOperandIndex() returns 2 (two leading
    operands); `dest` is assumed to be the second of them -- verify.
  * DestinationStyleOpInterface.td: the hasTensorSemantics() description
    changes from "RankedTensor" to "tensor", and the isa check in its
    default implementation is changed alongside it.
  * DestinationStyleOpInterface.cpp: the type check that decides which
    operands are pushed into outputTensorOperands is changed (see note).
  * CMakeLists.txt (Tensor IR): link MLIRDestinationStyleOpInterface.
  * BUILD.bazel: add ":DestinationStyleOpInterfaceTdFiles" and
    ":DestinationStyleOpInterface" to two dependency lists (presumably
    the Tensor ops TableGen target and the Tensor dialect library --
    confirm against the target names, which are outside the hunks).

NOTE(review): this copy of the patch cannot be applied as-is.
  - The line breaks inside every hunk have been collapsed into spaces,
    so git apply / GNU patch will not match the context lines.
  - Angle-bracketed template arguments appear to have been stripped,
    e.g. "std::pair getOutputsPositionRange()", "isa()" and a bare
    "DeclareOpInterfaceMethods,". Most visibly, the
    DestinationStyleOpInterface.cpp hunk now removes and re-adds the
    identical text "if (type.isa())", i.e. an apparent no-op.
    Judging by the description change, the intent is presumably
    RankedTensorType -> TensorType in both the .td and .cpp hunks, and
    the pair type is presumably std::pair<int64_t, int64_t>. Restore
    these from the upstream commit rather than guessing.
  - Widening the check from ranked to all tensor types would let
    unranked tensor operands count as tensor outputs / tensor semantics;
    confirm that the interface verifier and the consumers of
    hasTensorSemantics() and outputTensorOperands tolerate that.
  - Nit: the new description "only tensor input and outputs" should
    read "only tensor inputs and outputs".

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h --- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h +++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h @@ -16,6 +16,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/ParallelCombiningOpInterface.h" #include "mlir/Interfaces/ShapedOpInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -12,6 +12,7 @@ include "mlir/Dialect/Tensor/IR/TensorBase.td" include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/DestinationStyleOpInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/ParallelCombiningOpInterface.td" include "mlir/Interfaces/ShapedOpInterfaces.td" @@ -681,6 +682,7 @@ def Tensor_InsertOp : Tensor_Op<"insert", [ DeclareOpInterfaceMethods, + DestinationStyleOpInterface, Pure, TypesMatchWith<"result type matches type of dest", "dest", "result", @@ -724,6 +726,12 @@ build($_builder, $_state, resType, scalar, dest, indices); }]>]; + let extraClassDeclaration = [{ + std::pair getOutputsPositionRange() { + return {1, 2}; // `dest` operand + } + }]; + let hasFolder = 1; let hasVerifier = 1; } @@ -736,6 +744,7 @@ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, AttrSizedOperandSegments, + DestinationStyleOpInterface, Pure, OffsetSizeAndStrideOpInterface, TypesMatchWith<"expected result type to match dest type", @@ -862,6 +871,10 @@ /// Return the number of leading operands before the `offsets`, `sizes` and /// and `strides` operands.
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; } + + std::pair getOutputsPositionRange() { + return {1, 2}; // `dest` operand + } }]; let hasCanonicalizer = 1; diff --git a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td --- a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td +++ b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td @@ -238,7 +238,7 @@ }] >, InterfaceMethod< - /*desc=*/"Return whether the op has only RankedTensor input and outputs.", + /*desc=*/"Return whether the op has only tensor input and outputs.", /*retTy=*/"bool", /*methodName=*/"hasTensorSemantics", /*args=*/(ins), @@ -247,7 +247,7 @@ return llvm::all_of($_op->getOpOperands(), [&](OpOperand &opOperand) { return isScalar(&opOperand) || - opOperand.get().getType().template isa(); + opOperand.get().getType().template isa(); }); }] >, diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt @@ -24,6 +24,7 @@ MLIRArithUtils MLIRCastInterfaces MLIRComplexDialect + MLIRDestinationStyleOpInterface MLIRDialectUtils MLIRIR MLIRInferTypeOpInterface diff --git a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp --- a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp +++ b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp @@ -31,7 +31,7 @@ Type type = operand->get().getType(); if (type.isa()) outputBufferOperands.push_back(operand); - if (type.isa()) + if (type.isa()) outputTensorOperands.push_back(operand); } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -5095,6 +5095,7 @@ deps = [ ":CastInterfacesTdFiles",
":ControlFlowInterfacesTdFiles", + ":DestinationStyleOpInterfaceTdFiles", ":InferTypeOpInterfaceTdFiles", ":OpBaseTdFiles", ":ParallelCombiningOpInterfaceTdFiles", @@ -5153,6 +5154,7 @@ ":CastOpInterfaces", ":ComplexDialect", ":ControlFlowInterfaces", + ":DestinationStyleOpInterface", ":DialectUtils", ":IR", ":InferTypeOpInterface",