diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -153,6 +153,7 @@
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
@@ -178,6 +179,7 @@
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 /// A base class for pooling operation such as conv. The arguments must contain
@@ -358,6 +360,7 @@
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 class SingleInputPoolingBase_Op<string mnemonic>
@@ -417,6 +420,7 @@
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def PoolingMaxOp: SingleInputPoolingBase_Op<"pooling_max"> {
@@ -659,6 +663,7 @@
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 /// GenericOp with Indexing (i.e. multi-for style in which the region is passed
@@ -796,6 +801,7 @@
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -818,6 +824,7 @@
   let printer = [{ return ::printNamedStructuredOp(p, *this); }];
   let verifier = [{ return ::verifyNamedStructuredOp(*this); }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 // This file is auto-generated from a tc specification.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1153,38 +1153,6 @@
 // TODO: Consider making all this boilerplate easy to autogenerate
 // with Tablegen. This seems a desirable property in the context of OpInterfaces
 // where a Linalg "named" op **isa** LinalgOp.
-LogicalResult ConvOp::fold(ArrayRef<Attribute>,
-                           SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult PoolingMaxOp::fold(ArrayRef<Attribute>,
-                                 SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult PoolingMinOp::fold(ArrayRef<Attribute>,
-                                 SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult PoolingSumOp::fold(ArrayRef<Attribute>,
-                                 SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult CopyOp::fold(ArrayRef<Attribute>,
-                           SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult FillOp::fold(ArrayRef<Attribute>,
-                           SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult GenericOp::fold(ArrayRef<Attribute>,
-                              SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult IndexedGenericOp::fold(ArrayRef<Attribute>,
-                                     SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
 OpFoldResult ReshapeOp::fold(ArrayRef<Attribute>) {
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
@@ -1299,58 +1267,64 @@
   return verifyGenericOp(op);
 }
 
+struct EraseDeadLinalgOp : public RewritePattern {
+  EraseDeadLinalgOp(PatternBenefit benefit = 1)
+      : RewritePattern(benefit, MatchAnyOpTypeTag()) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    auto linalgOp = dyn_cast<LinalgOp>(op);
+    if (!linalgOp)
+      return failure();
+    for (Value v : linalgOp.getInputsAndOutputBuffers()) {
+      // Linalg "inputs" may be either tensor or memref type.
+      // tensor<0xelt_type> is a convention that may not always mean
+      // "0 iterations". Only erase in cases we see memref<...x0x...>.
+      auto mt = v.getType().dyn_cast<MemRefType>();
+      if (!mt)
+        continue;
+      if (llvm::is_contained(mt.getShape(), 0)) {
+        rewriter.eraseOp(linalgOp);
+        return success();
+      }
+    }
+    return failure();
+  }
+};
+
+#define CANONICALIZERS_AND_FOLDERS(XXX)                                       \
+  void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results,    \
+                                        MLIRContext *context) {               \
+    results.insert<EraseDeadLinalgOp>();                                      \
+  }                                                                           \
+                                                                              \
+  LogicalResult XXX::fold(ArrayRef<Attribute>,                                \
+                          SmallVectorImpl<OpFoldResult> &) {                  \
+    return foldMemRefCast(*this);                                             \
+  }
+
+CANONICALIZERS_AND_FOLDERS(ConvOp);
+CANONICALIZERS_AND_FOLDERS(PoolingMaxOp);
+CANONICALIZERS_AND_FOLDERS(PoolingMinOp);
+CANONICALIZERS_AND_FOLDERS(PoolingSumOp);
+CANONICALIZERS_AND_FOLDERS(CopyOp);
+CANONICALIZERS_AND_FOLDERS(FillOp);
+CANONICALIZERS_AND_FOLDERS(GenericOp);
+CANONICALIZERS_AND_FOLDERS(IndexedGenericOp);
+
 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc"
 
 // TODO: Determine whether we can generate the folders and verifiers.
-LogicalResult BatchMatmulOp::fold(ArrayRef<Attribute>,
-                                  SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult DotOp::fold(ArrayRef<Attribute>,
-                          SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult MatmulOp::fold(ArrayRef<Attribute>,
-                             SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult MatvecOp::fold(ArrayRef<Attribute>,
-                             SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvWOp::fold(ArrayRef<Attribute>,
-                            SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvNWCOp::fold(ArrayRef<Attribute>,
-                              SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvNCWOp::fold(ArrayRef<Attribute>,
-                              SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvHWOp::fold(ArrayRef<Attribute>,
-                             SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvNHWCOp::fold(ArrayRef<Attribute>,
-                               SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvNCHWOp::fold(ArrayRef<Attribute>,
-                               SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvDHWOp::fold(ArrayRef<Attribute>,
-                              SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvNDHWCOp::fold(ArrayRef<Attribute>,
-                                SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvNCDHWOp::fold(ArrayRef<Attribute>,
-                                SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
+CANONICALIZERS_AND_FOLDERS(BatchMatmulOp);
+CANONICALIZERS_AND_FOLDERS(DotOp);
+CANONICALIZERS_AND_FOLDERS(MatmulOp);
+CANONICALIZERS_AND_FOLDERS(MatvecOp);
+CANONICALIZERS_AND_FOLDERS(ConvWOp);
+CANONICALIZERS_AND_FOLDERS(ConvNWCOp);
+CANONICALIZERS_AND_FOLDERS(ConvNCWOp);
+CANONICALIZERS_AND_FOLDERS(ConvHWOp);
+CANONICALIZERS_AND_FOLDERS(ConvNHWCOp);
+CANONICALIZERS_AND_FOLDERS(ConvNCHWOp);
+CANONICALIZERS_AND_FOLDERS(ConvDHWOp);
+CANONICALIZERS_AND_FOLDERS(ConvNDHWCOp);
+CANONICALIZERS_AND_FOLDERS(ConvNCDHWOp);
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -732,19 +732,16 @@
 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                 ArrayRef<AffineExpr> exprs,
                                                 MLIRContext *context) {
+  // Size 0 corner case is useful for canonicalizations.
+  if (llvm::is_contained(sizes, 0))
+    return getAffineConstantExpr(0, context);
+
+  auto maps = AffineMap::inferFromExprList(exprs);
+  assert(!maps.empty() && "Expected one non-empty map");
+  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
+
   AffineExpr expr;
   bool dynamicPoisonBit = false;
-  unsigned numDims = 0;
-  unsigned nSymbols = 0;
-  // Compute the number of symbols and dimensions of the passed exprs.
-  for (AffineExpr expr : exprs) {
-    expr.walk([&numDims, &nSymbols](AffineExpr d) {
-      if (AffineDimExpr dim = d.dyn_cast<AffineDimExpr>())
-        numDims = std::max(numDims, dim.getPosition() + 1);
-      else if (AffineSymbolExpr symbol = d.dyn_cast<AffineSymbolExpr>())
-        nSymbols = std::max(nSymbols, symbol.getPosition() + 1);
-    });
-  }
   int64_t runningSize = 1;
   for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
     int64_t size = std::get<1>(en);
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -172,3 +172,34 @@
 // CHECK-LABEL: @no_fold_memref_reshape
 // CHECK: linalg.reshape
 // CHECK: linalg.reshape
+
+// -----
+
+#accesses = [
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> (i)>
+]
+
+#trait = {
+  args_in = 1,
+  args_out = 1,
+  indexing_maps = #accesses,
+  iterator_types = ["parallel"]
+}
+
+func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
+  // memref<0xf32> is expected to be dce'ed
+  linalg.copy(%arg0, %arg0): memref<0xf32>, memref<0xf32>
+
+  // tensor<0xf32> cannot be dce'ed
+  %1 = linalg.generic #trait %arg1 {
+  ^bb(%0: f32) :
+    linalg.yield %0 : f32
+  } : tensor<0xf32> -> tensor<0xf32>
+
+  return %1: tensor<0xf32>
+}
+// CHECK-LABEL: @dce_zero_memref
+// CHECK-NOT: linalg.copy
+// CHECK-NEXT: linalg.generic
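
Note: a quick way to see the new canonicalization fire (a sketch, assuming an mlir-opt binary built with this patch) is to run the new test through the canonicalizer, the same way the test file's RUN line does:

  mlir-opt -canonicalize mlir/test/Dialect/Linalg/canonicalize.mlir

Per the CHECK lines above, @dce_zero_memref should come out with the linalg.copy on memref<0xf32> erased by EraseDeadLinalgOp, while the tensor-typed linalg.generic survives because tensor<0xf32> is not taken to imply zero iterations:

  func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
    %1 = linalg.generic #trait %arg1 {
    ^bb(%0: f32) :
      linalg.yield %0 : f32
    } : tensor<0xf32> -> tensor<0xf32>
    return %1 : tensor<0xf32>
  }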