diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h @@ -62,7 +62,7 @@ /// Builds the indexing expressions for a ConvOp `op`. Returns the vector of /// AffineMaps representing: -/// `stride[i] * xs[i] + dilation[i] * zs[i]` +/// `stride[i] * xs[i] + dilation[i] * zs[i] - pad_low[i]` SmallVector weightedConvInputIndex(ConvOp op, ArrayRef xs, ArrayRef zs); diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -265,13 +265,18 @@ ``` }]; - // TODO(ntv) padding. - // Following the TF source of truth above, strides and dilations are integer - // attributes of the same rank as the number of window dimensions. + // Following the TF source of truth above, strides, dilations and padding are + // integer attributes of the same rank as the number of window dimensions. + // The padding attribute specifies the amount of zero padding to be applied to + // the base area, which is an n-d array of (low, high) padding. Each pair has + // the low padding as the first element and the high padding as the second + // element. Using padding is equivalent to inserting those same zero values + // into the input before doing the convolution. 
let arguments = (ins AnyStridedMemRef:$filter, AnyStridedMemRef:$input, AnyStridedMemRef:$output, OptionalAttr:$strides, - OptionalAttr:$dilations); + OptionalAttr:$dilations, + OptionalAttr:$padding); let extraClassDeclaration = libraryCallName # [{ // TODO(ntv) extend to support more than 1 dimensions and potentially @@ -314,9 +319,17 @@ .cast().getValue().getSExtValue(); } - // F(z0, ..., zN-1, q, k) * I(b, x0 + z0, ..., xN-1 + zN-1, q) -> - // O(b, x0, ..., xN-1, k) - // for N equal to `nWindow`. + int64_t getLowPad(unsigned i) { + assert(i < getNumWindowLoops()); + if (!padding().hasValue()) return 0; + return padding().getValue().getValue({i, 0}); + } + + // F(z0, ..., zN-1, q, k) * + // I(b, x0 + z0 - pad_low_0, ..., xN-1 + zN-1 - pad_low_N-1, q) + // -> O(b, x0, ..., xN-1, k) + // for N equal to `nWindow`. If there is no padding attribute, it will be + // ignored. llvm::Optional> referenceIndexingMaps() { MLIRContext *context = getContext(); auto nWin = getNumWindowLoops(); @@ -346,7 +359,9 @@ // filter[z[0], ..., z[N-1], q, k] AffineMap::get(idx, 0, concat(concat(zs, qs), ks)), // input[b, - // x[0]*s[0] + d[0]*z[0], ..., x[N-1]*s[N-1] + d[N-1]*z[N-1], + // x[0]*s[0] + d[0]*z[0] - pad_low[0], + // ... + // x[N-1]*s[N-1] + d[N-1]*z[N-1] - pad_low[N-1], // q] AffineMap::get(idx, 0, concat(concat(bs, ws), qs)), // output[b, x[0], ..., x[N-1], k] diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -900,8 +900,12 @@ assert(xs.size() == zs.size()); SmallVector res; res.reserve(xs.size()); - for (unsigned i = 0, e = xs.size(); i < e; ++i) - res.push_back(op.getStride(i) * xs[i] + op.getDilation(i) * zs[i]); + for (unsigned i = 0, e = xs.size(); i < e; ++i) { + // TODO(ntv): add a level of indirection to linalg.generic. 
+ auto expr = + op.getStride(i) * xs[i] + op.getDilation(i) * zs[i] - op.getLowPad(i); + res.push_back(expr); + } return res; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -152,6 +152,18 @@ "expected linalg op with buffer semantics"); assert(consumer.hasBufferSemantics() && "expected linalg op with buffer semantics"); + + if (auto convOp = dyn_cast(producer.getOperation())) { + // TODO(ntv): add a level of indirection to linalg.generic. + if (convOp.padding()) + llvm_unreachable("Unexpected conv with padding"); + } + if (auto convOp = dyn_cast(consumer.getOperation())) { + // TODO(ntv): add a level of indirection to linalg.generic. + if (convOp.padding()) + llvm_unreachable("Unexpected conv with padding"); + } + auto subView = dyn_cast_or_null( consumer.getInput(consumerIdx).getDefiningOp()); auto slice = diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -177,6 +177,53 @@ template class LinalgScopedEmitter { public: + /// Returns the input value of convOp. If the indices in `imIdx` are out of + /// bounds, returns 0 instead. + static ValueHandle getConvOpInput(ConvOp convOp, IndexedValueType im, + ArrayRef imIdx) { + // TODO(ntv): add a level of indirection to linalg.generic. + if (!convOp.padding()) + return im(imIdx); + + ValueHandle zeroIndex = std_constant_index(0); + // Seed the out-of-bounds chain with `false`; every bound check below is + // OR-ed into it, so a `true` seed would make the select always yield zero. + SmallVector conds = { + std_constant_int(/*value=*/0, /*width=*/1)}; + SmallVector clampedImIdx; + for (auto iter : llvm::enumerate(imIdx)) { + int idx = iter.index(); + auto dim = iter.value(); + // Only need to iterate over the window dimensions. 
+ if (idx == 0 || idx == static_cast(imIdx.size()) - 1) { + clampedImIdx.push_back(dim); + continue; + } + + using edsc::op::operator<; + using edsc::op::operator>=; + using edsc::op::operator||; + conds.push_back(conds.back() || (dim < zeroIndex)); + ValueHandle bound = std_dim(convOp.input(), idx); + conds.push_back(conds.back() || (dim >= bound)); + + // When padding is involved, the indices will only be shifted to negative, + // so having a max op is enough. + auto *context = ScopedContext::getContext(); + auto maxMap = AffineMap::get(/*dimCount=*/1, 0, + {getAffineDimExpr(/*position=*/0, context), + getAffineConstantExpr(0, context)}); + clampedImIdx.push_back( + affine_max(dim.getType(), maxMap, ValueRange{dim})); + } + + auto b = ScopedContext::getBuilder(); + Type type = convOp.input().getType().cast().getElementType(); + ValueHandle zero = std_constant(type, b.getZeroAttr(type)); + ValueHandle readInput = im(clampedImIdx); + return std_select(conds.back(), zero, readInput); + } + static void emitScalarImplementation(ArrayRef allIvs, ConvOp convOp) { assert(convOp.hasBufferSemantics() && "expected linalg op with buffer semantics"); @@ -192,8 +239,10 @@ SmallVector oIdx( makeCanonicalAffineApplies(b, loc, maps[2], allIvs)); IndexedValueType F(convOp.filter()), I(convOp.input()), O(convOp.output()); + // Emit scalar form. - O(oIdx) += F(fIdx) * I(imIdx); + ValueHandle paddedInput = getConvOpInput(convOp, I, imIdx); + O(oIdx) += F(fIdx) * paddedInput; } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -193,6 +193,12 @@ auto linalgOp = cast(op); assert(linalgOp.hasBufferSemantics() && "expected linalg op with buffer semantics"); + if (auto convOp = dyn_cast(op)) { + // TODO(ntv): add a level of indirection to linalg.generic. 
+ if (convOp.padding()) + llvm_unreachable("Unexpected conv with padding"); + } + edsc::ScopedContext scope(rewriter, op->getLoc()); if (auto fillOp = dyn_cast(op)) { @@ -295,6 +301,12 @@ assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) && "DRR failure case must be a precondition"); + if (auto convOp = dyn_cast(op)) { + // TODO(ntv): add a level of indirection to linalg.generic. + if (convOp.padding()) + llvm_unreachable("Unexpected conv with padding"); + } + LinalgOp linOp = cast(op); assert(linOp.hasBufferSemantics() && "expected linalg op with buffer semantics"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -160,6 +160,12 @@ OperationFolder *folder) { assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); + if (auto convOp = dyn_cast(op.getOperation())) { + // TODO(ntv): add a level of indirection to linalg.generic. + if (convOp.padding()) + llvm_unreachable("Unexpected conv with padding"); + } + // 1. Promote the specified views and use them in the new op. ScopedContext scope(b, op.getLoc()); auto promotedBufferAndViews = promoteSubViews( diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -342,6 +342,12 @@ tileSizes.size() && "expected matching number of tile sizes and loops"); + if (auto convOp = dyn_cast(op.getOperation())) { + // TODO(ntv): add a level of indirection to linalg.generic. + if (convOp.padding()) + llvm_unreachable("Unexpected conv with padding"); + } + // If permutation is empty, use the identity. Build the permutation map // otherwise. 
auto invPermutationMap = AffineMap::getMultiDimIdentityMap( @@ -421,6 +427,12 @@ if (tileSizes.empty()) return llvm::None; + if (auto convOp = dyn_cast(op.getOperation())) { + // TODO(ntv): add a level of indirection to linalg.generic. + if (convOp.padding()) + llvm_unreachable("Unexpected conv with padding"); + } + // The following uses the convention that "tiling by zero" skips tiling a // particular dimension. This convention is significantly simpler to handle // instead of adjusting affine maps to account for missing dimensions. diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -7,6 +7,7 @@ // CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> // CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)> // CHECK-DAG: #[[strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)> +// CHECK-DAG: #[[clampMinMap:.*]] = affine_map<(d0) -> (d0, 0)> // CHECK-DAG: #[[Stride2Dilation1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)> // CHECK-DAG: #[[Stride2Dilation4:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)> @@ -212,6 +213,44 @@ // CHECK: %{{.*}} = addf %{{.*}}, %{{.*}} : f32 // CHECK: store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref +func @conv_padding(%arg0: memref, + %arg1: memref, + %arg2: memref) { + linalg.conv(%arg0, %arg1, %arg2) {dilations = [1, 1], + padding = dense<[[0, 1], [1, 1]]> : tensor<2x2xi64>, + strides = [1, 1]} : + memref, memref, memref + return +} +// CHECK-LABEL: func @conv_padding +// CHECK: %{{.*}}: memref, %{{.*}}: memref, %{{.*}}: memref) { +// CHECK: %[[ZERO:.*]] = constant 0.000000e+00 : f32 +// CHECK: %[[Z0:.*]] = dim %arg0, 0 : memref +// CHECK: %[[Z1:.*]] = dim %arg0, 1 : memref +// CHECK: %[[Q:.*]] = dim %arg0, 2 : memref +// CHECK: %[[K:.*]] = dim %arg0, 3 : memref +// CHECK: 
%[[B:.*]] = dim %arg1, 0 : memref +// CHECK: %[[X0:.*]] = dim %arg2, 1 : memref +// CHECK: %[[X1:.*]] = dim %arg2, 2 : memref +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[B]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[X0]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[X1]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[Q]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[Z0]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[Z1]] step %{{.*}} { +// CHECK: %[[SUM0:.*]] = affine.apply #{{.*}}(%{{.*}}, %{{.*}}) +// CHECK: %[[SUM1:.*]] = affine.apply #{{.*}}(%{{.*}}, %{{.*}}) +// CHECK: %[[IDX:.*]] = affine.max #[[clampMinMap]](%[[SUM0]]) +// CHECK: %[[IDY:.*]] = affine.max #[[clampMinMap]](%[[SUM1]]) +// CHECK: %{{.*}} = load %{{.*}}[%{{.*}}, %[[IDX]], %[[IDY]], %{{.*}}] : memref +// CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : f32 +// CHECK: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref +// CHECK: %{{.*}} = mulf %{{.*}}, %{{.*}} : f32 +// CHECK: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref +// CHECK: %{{.*}} = addf %{{.*}}, %{{.*}} : f32 +// CHECK: store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref + func @foo(%0: f32, %1: f32, %2: f32) -> (f32, f32) { %f0 = constant 0.0 : f32 return %f0, %f0 : f32, f32 diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -222,6 +222,28 @@ // ----- +func @conv_padding(%arg0: memref, + %arg1: memref, + %arg2: memref) { + linalg.conv(%arg0, %arg1, %arg2) {dilations = [1, 1], + padding = dense<[[0, 1], [1, 1]]> : tensor<2x2xi64>, + strides = [1, 1]} : + memref, memref, memref + return +} + +// CHECK-LABEL: func @conv_padding( +// CHECK: linalg.conv(%{{.*}}, %{{.*}}, %{{.*}}) { +// CHECK-SAME: dilations = [1, 1], 
+// CHECK-SAME: padding = dense<[ +// CHECK-SAME: [0, 1], [1, 1]]> : tensor<2x2xi64>, +// CHECK-SAME: strides = [1, 1]} : +// CHECK-SAME: memref, +// CHECK-SAME: memref, +// CHECK-SAME: memref + +// ----- + // CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> // CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>