diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -215,7 +215,7 @@
       AffineExpr i, r_j;
       bindDims(context, i, r_j);
       return SmallVector<AffineMap, 8>{
-          AffineMap::get(2, 0, {i, r_j}, context),
+          AffineMap::get(2, 0, {i, r_j}, context),
           AffineMap::get(2, 0, {r_j}, context),
           AffineMap::get(2, 0, {i}, context)
       };
@@ -314,6 +314,12 @@
       if (!padding().hasValue()) return 0;
       return padding().getValue().getValue<int64_t>({i, 0});
     }
+
+    int64_t getHighPad(unsigned i) {
+      assert(i < getNumWindowLoops());
+      if (!padding().hasValue()) return 0;
+      return padding().getValue().getValue<int64_t>({i, 1});
+    }
   }];
 }
@@ -357,6 +363,11 @@
     unsigned getNumOutputFeatureDimensions() { return 1; }
 
+    unsigned getNumSpatialDimensions() {
+      return getOutputShapedType(0).getRank() - getNumBatchDimensions() -
+             getNumOutputFeatureDimensions();
+    }
+
     llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
       // Outer parallel loops are always the number of output dimensions; i.e.
       // [b, xs, q] in the TF notation above.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -337,6 +337,15 @@
                : (Value)std_select(conds.back(), zero, readInput);
   }
 
+  /// Returns true if `convOp` has non-zero padding.
+  static bool hasPadding(ConvOp convOp) {
+    for (unsigned i = 0, e = convOp.getNumSpatialDimensions(); i < e; ++i) {
+      if (convOp.getLowPad(i) > 0 || convOp.getHighPad(i) > 0)
+        return true;
+    }
+    return false;
+  }
+
   static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
     assert(convOp.hasBufferSemantics() &&
            "expected linalg op with buffer semantics");
@@ -352,14 +361,19 @@
     SmallVector<Value, 8> oIdx(
         makeCanonicalAffineApplies(b, loc, maps[2], allIvs));
-    // Padded conv involves an affine.max in the memory access which is not
-    // allowed by affine.load. Override to always use an StdIndexedValue.
-    StdIndexedValue I(convOp.input());
     IndexedValueType F(convOp.filter()), O(convOp.output());
 
-    // Emit scalar form.
-    Value paddedInput = getConvOpInput(convOp, I, imIdx);
-    O(oIdx) += F(fIdx) * paddedInput;
+    // Emit scalar form. Padded conv involves an affine.max in the memory access
+    // which is not allowed by affine.load. Override to use an StdIndexedValue
+    // when there is non-zero padding.
+    if (hasPadding(convOp)) {
+      StdIndexedValue I(convOp.input());
+      Value paddedInput = getConvOpInput(convOp, I, imIdx);
+      O(oIdx) += F(fIdx) * paddedInput;
+    } else {
+      IndexedValueType I(convOp.input());
+      O(oIdx) += F(fIdx) * I(imIdx);
+    }
   }
 };
diff --git a/mlir/test/Dialect/Linalg/affine.mlir b/mlir/test/Dialect/Linalg/affine.mlir
--- a/mlir/test/Dialect/Linalg/affine.mlir
+++ b/mlir/test/Dialect/Linalg/affine.mlir
@@ -54,6 +54,9 @@
 // CHECK: affine.for %{{.*}} = 0 to %[[Q]] {
 // CHECK: affine.for %{{.*}} = 0 to %[[Z0]] {
 // CHECK: %[[SUM:.*]] = affine.apply #[[stride2Dilation1]](%{{.*}}, %{{.*}})
+// No padding needed here; only affine loads.
+// CHECK-NEXT: affine.load
+// CHECK-NEXT: affine.load
 
 func @conv_padding(%arg0: memref<?x?x?x?xf32>,
                    %arg1: memref<?x?x?x?xf32>,
@@ -85,8 +88,8 @@
 // CHECK: %[[SUM1:.*]] = affine.apply #{{.*}}(%{{.*}}, %{{.*}})
 // CHECK: %[[IDX:.*]] = affine.max #[[clampMinMap]](%[[SUM0]])
 // CHECK: %[[IDY:.*]] = affine.max #[[clampMinMap]](%[[SUM1]])
-// Padded conv involves an affine.max in the memory access which is not
-// allowed by affine.load. Override to always use an std.load.
+// Padded conv involves an affine.max in the memory access and this is not
+// allowed by affine.load. Use std.load in such cases.
 // CHECK: %{{.*}} = load %{{.*}}[%{{.*}}, %[[IDX]], %[[IDY]], %{{.*}}] : memref<?x?x?x?xf32>
 // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : f32
 // CHECK: %{{.*}} = affine.load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?x?xf32>
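
A rough sketch of the two scalar-body shapes the new hasPadding() switch selects between, matching the test expectations above (SSA names, memref shapes, and the clamp map are illustrative placeholders, not taken from the patch):

  // Unpadded conv body (hypothetical IR): plain affine accesses are legal.
  %a = affine.load %input[%b, %i0, %i1, %q] : memref<?x?x?x?xf32>
  %f = affine.load %filter[%k0, %k1, %q, %k] : memref<?x?x?x?xf32>

  // Padded conv body (hypothetical IR): the clamped index comes from
  // affine.max, which affine.load rejects, so the input is read with std.load
  // and the zero padding value is selected when out of bounds.
  %idx = affine.max #clampMinMap(%sum0)
  %v   = load %input[%b, %idx, %idy, %q] : memref<?x?x?x?xf32>
  %pad = select %cond, %zero, %v : f32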