diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -17,6 +17,7 @@ include "mlir/Dialect/Linalg/IR/LinalgBase.td" include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td" include "mlir/Interfaces/CopyOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" // The Linalg `NInputs` trait provides the API for ops that are known // to have a specified number of inputs, all passed as operands. @@ -43,13 +44,14 @@ // first operands. These may be optionally followed by non-view operands // depending on the specific Linalg op. class LinalgStructuredBase_Op props> - : Op { -} + : Op {} class LinalgStructured_Op props> : LinalgStructuredBase_Op { + !listconcat(props, [ + StructuredOpTraits, + DeclareOpInterfaceMethods])> { code libraryCallName = [{ std::string getLibraryCallName() { return generateLibraryCallName(getOperation()); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1039,6 +1039,13 @@ /////// Operations corresponding to library calls defined with Tablegen //////// +void FillOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), output(), + SideEffects::DefaultResource::get()); +} + static LogicalResult verify(FillOp op) { auto viewType = op.getOutputShapedType(0); auto fillType = op.value().getType(); @@ -1047,6 +1054,15 @@ return success(); } +void CopyOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), input(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), output(), + SideEffects::DefaultResource::get()); +} + static LogicalResult verify(CopyOp op) { 
auto outputViewType = op.getOutputShapedType(0); auto inputViewType = op.getInputShapedType(0); @@ -1093,6 +1109,17 @@ return success(); } +void ConvOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), input(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), filter(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), output(), + SideEffects::DefaultResource::get()); +} + static LogicalResult verify(ConvOp op) { auto oType = op.output().getType().cast(); auto fType = op.filter().getType().cast(); @@ -1142,6 +1169,16 @@ return success(); } +#define DEFINE_POOLING_OP_GET_EFFECTS(OP_NAME) \ + void OP_NAME::getEffects( \ + SmallVectorImpl> \ + &effects) { \ + effects.emplace_back(MemoryEffects::Read::get(), input(), \ + SideEffects::DefaultResource::get()); \ + effects.emplace_back(MemoryEffects::Write::get(), output(), \ + SideEffects::DefaultResource::get()); \ + } + static LogicalResult verify(PoolingMaxOp op) { return verifySingleInputPoolingOp(op); } @@ -1152,6 +1189,10 @@ return verifySingleInputPoolingOp(op); } +DEFINE_POOLING_OP_GET_EFFECTS(PoolingMaxOp); +DEFINE_POOLING_OP_GET_EFFECTS(PoolingMinOp); +DEFINE_POOLING_OP_GET_EFFECTS(PoolingSumOp); + namespace { struct EraseDeadLinalgOp; struct FoldTensorCastOp; diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -334,3 +334,23 @@ return %1: tensor<3x?xf32> } + + // ----- + + // An unused matmul that only reads its inputs and allocates its tensor result is erased by canonicalization; a matmul that writes a memref output buffer must be kept. + // CHECK-LABEL: func @linalg_effects( +// CHECK-SAME: %[[A:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[B:[a-z0-9]*]]: memref +// CHECK-SAME: %[[C:[a-z0-9]*]]: tensor +func @linalg_effects(%a : tensor, %b : memref, %c : tensor) { + // CHECK-NOT: %{{.*}} = linalg.matmul + %t = 
linalg.matmul ins(%a, %b : tensor, memref) + init(%c : tensor) -> tensor + + // CHECK-NOT: %{{.*}} = 
linalg.matmul + // CHECK: linalg.matmul ins(%[[A]], %[[C]] : + // CHECK-SAME: outs(%[[B]] : + linalg.matmul ins(%a, %c : tensor, tensor) + outs(%b : memref) + return +} diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -1451,8 +1451,9 @@ StringRef linalgOpName, ComprehensionParsingState &state) { const char *header = R"FMT( def {0} : LinalgStructuredBase_Op<"{1}", [ - NamedStructuredOpTrait, AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + NamedStructuredOpTrait, SingleBlockImplicitTerminator<"YieldOp">]> { let arguments = (ins Variadic:$inputs, Variadic:$output_buffers, @@ -1589,6 +1590,23 @@ LogicalResult {0}::fold(ArrayRef, SmallVectorImpl &) {{ return foldMemRefCast(*this); + } + void {0}::getEffects(SmallVectorImpl< + SideEffects::EffectInstance >&effects) {{ + for (Value value : this->getOperation()->getResults()) {{ + effects.emplace_back(MemoryEffects::Allocate::get(), value, + SideEffects::DefaultResource::get()); + } + for (Value value : getInputBuffers()) {{ + effects.emplace_back(MemoryEffects::Read::get(), value, + SideEffects::DefaultResource::get()); + } + for (Value value : getOutputBuffers()) {{ + effects.emplace_back(MemoryEffects::Read::get(), value, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), value, + SideEffects::DefaultResource::get()); + } })FMT"; os << llvm::formatv(canonicalizersAndFoldersFmt, cppOpName); }