diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/Functional.h"
@@ -343,9 +344,13 @@
          "expected matching number of tile sizes and loops");
 
   if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
-    // TODO(ntv): add a level of indirection to linalg.generic.
-    if (convOp.padding())
-      llvm_unreachable("Unexpected conv with padding");
+    // With padding, only tiling along the batch dimension (the first loop) is
+    // supported, so every other tile size must be a constant zero. Dynamic
+    // tile sizes cannot be proven zero and are rejected conservatively.
+    if (convOp.padding() &&
+        !llvm::all_of(tileSizes.drop_front(),
+                      [](Value val) { return matchPattern(val, m_Zero()); }))
+      return llvm::None;
   }
 
   // If permutation is empty, use the identity. Build the permutation map
@@ -427,12 +432,6 @@
   if (tileSizes.empty())
     return llvm::None;
 
-  if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
-    // TODO(ntv): add a level of indirection to linalg.generic.
-    if (convOp.padding())
-      llvm_unreachable("Unexpected conv with padding");
-  }
-
   // The following uses the convention that "tiling by zero" skips tiling a
   // particular dimension. This convention is significantly simpler to handle
   // instead of adjusting affine maps to account for missing dimensions.
@@ -443,6 +442,16 @@
   if (llvm::all_of(tileSizes, [](int64_t v) { return v == 0; }))
     return llvm::None;
 
+  if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
+    // With padding, only tiling along the batch dimension (the first loop) is
+    // supported. Check this before the tile-size constants are materialized
+    // below so that no ops are created for a conv that cannot be tiled.
+    if (convOp.padding() &&
+        !llvm::all_of(tileSizes.drop_front(),
+                      [](int64_t val) { return val == 0; }))
+      return llvm::None;
+  }
+
   // Create a builder for tile size constants.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(op);