diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h @@ -29,6 +29,9 @@ namespace linalg { class ConvOp; +class PoolingMaxOp; +class PoolingMinOp; +class PoolingSumOp; /// Returns the name mangled library call name to disambiguate between different /// overloads at the C level. The name mangling scheme is basic and uses MLIR @@ -60,12 +63,21 @@ SmallVector makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context); -/// Builds the indexing expressions for a ConvOp `op`. Returns the vector of -/// AffineMaps representing: -/// `stride[i] * xs[i] + dilation[i] * zs[i] - pad_low[i]` -SmallVector weightedConvInputIndex(ConvOp op, - ArrayRef xs, - ArrayRef zs); +/// Builds the indexing expressions for a ConvOp/PoolingOp `op`. Returns the +/// vector of AffineMaps representing: +/// `stride[i] * outputDims[i] + dilation[i] * windowDims[i] - pad_low[i]` +SmallVector +weightedPoolingInputIndex(ConvOp op, ArrayRef outputDims, + ArrayRef windowDims); +SmallVector +weightedPoolingInputIndex(PoolingMaxOp op, ArrayRef outputDims, + ArrayRef windowDims); +SmallVector +weightedPoolingInputIndex(PoolingMinOp op, ArrayRef outputDims, + ArrayRef windowDims); +SmallVector +weightedPoolingInputIndex(PoolingSumOp op, ArrayRef outputDims, + ArrayRef windowDims); /// Returns `maybeMap.get()` if `maybeMap` is set, otherwise returns the /// symbol-less identity map of `rank`. diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -247,7 +247,69 @@ let hasFolder = 1; } -def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> { +/// A base class for pooling operation such as conv. 
The arguments must contain +/// optional arguments `strides`, `dilations` and `padding` with the following types: +/// OptionalAttr:$strides +/// OptionalAttr:$dilations +/// OptionalAttr:$padding +/// `strides` denotes the step of each window along the dimension. +class PoolingBase_Op props> + : LinalgStructured_Op { + let description = [{ + Performs an N-D pooling operation similarly to the description in the TF + documentation: + https://www.tensorflow.org/api_docs/python/tf/nn/pool + + Unlike the TF operation, this operation does not pool over the batch and + channel dimensions. It only takes operands of rank `N`. + + ``` + output[x[0], ..., x[N-1]] = + REDUCE_{z[0], ..., z[N-1]} + input[ + x[0] * strides[0] - pad_before[0] + dilation_rate[0]*z[0], + ... + x[N-1]*strides[N-1] - pad_before[N-1] + dilation_rate[N-1]*z[N-1] + ], + ``` + + The supported optional attributes are: + - strides: an i64 array specifying the stride (i.e. step) for window + loops. + - dilations: an i64 array specifying the filter upsampling/input + downsampling rate. + - padding: an i64 array of pairs (low, high) specifying the number of + elements to pad along a dimension. + + If strides or dilations attributes are missing then the default value is + one for each of the input dimensions. Similarly, padding values are zero + for both low and high in each of the dimensions, if not specified.
+ }]; + + code commonUtils = libraryCallName # [{ + int64_t getStride(unsigned i) { + assert(i < getNumWindowLoops()); + if (!strides().hasValue()) return 1; + return strides()->getValue()[i] + .cast().getValue().getSExtValue(); + } + + int64_t getDilation(unsigned i) { + assert(i < getNumWindowLoops()); + if (!dilations().hasValue()) return 1; + return dilations()->getValue()[i] + .cast().getValue().getSExtValue(); + } + + int64_t getLowPad(unsigned i) { + assert(i < getNumWindowLoops()); + if (!padding().hasValue()) return 0; + return padding().getValue().getValue({i, 0}); + } + }]; +} + +def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> { let description = [{ Generic n-D convolution as described in the TF documentation: @@ -278,7 +340,7 @@ OptionalAttr:$dilations, OptionalAttr:$padding); - let extraClassDeclaration = libraryCallName # [{ + let extraClassDeclaration = commonUtils # [{ // TODO(ntv) extend to support more than 1 dimensions and potentially // grouping too. unsigned getNumBatchDimensions() { return 1; } @@ -305,26 +367,6 @@ return iters; } - int64_t getStride(unsigned i) { - assert(i < getNumWindowLoops()); - if (!strides().hasValue()) return 1; - return strides()->getValue()[i] - .cast().getValue().getSExtValue(); - } - - int64_t getDilation(unsigned i) { - assert(i < getNumWindowLoops()); - if (!dilations().hasValue()) return 1; - return dilations()->getValue()[i] - .cast().getValue().getSExtValue(); - } - - int64_t getLowPad(unsigned i) { - assert(i < getNumWindowLoops()); - if (!padding().hasValue()) return 0; - return padding().getValue().getValue({i, 0}); - } - // F(z0, ..., zN-1, q, k) * // I(b, x0 + z0 - pad_low_0, ..., xN-1 + zN-1 - pad_low_N-1, q) // -> O(b, x0, ..., xN-1, k) @@ -354,7 +396,7 @@ // Window reduction dims: sum_{z[0], ..., z[N-1], q} auto zs = makeAffineDimExprs(nWin, idx, context); // Construct the weighedSum expression. 
- auto ws = weightedConvInputIndex(*this, xs, zs); + auto ws = weightedPoolingInputIndex(*this, xs, zs); return SmallVector{ // filter[z[0], ..., z[N-1], q, k] AffineMap::get(idx, 0, concat(concat(zs, qs), ks)), @@ -374,6 +416,87 @@ let hasFolder = 1; } +class SingleInputPoolingBase_Op + : PoolingBase_Op, NOutputs<1>]> { + let description = [{ + Base class for single-input pooling operations; each derived op defines the + reduction applied over the window (e.g. max, min or sum). + + TODO: Figure out a better way to handle window dimensions, i.e., eliminate + the fake memref. + The window dimensions are specified by the `windowDims` argument. The i-th + dimension in the shape of `windowDims` denotes the size of the window along + dimension i. For example, if the window size is 2x3, then a memref<2x3> + should be passed to the operation as `windowDims`. + }]; + + let arguments = (ins AnyStridedMemRef:$input, + AnyStridedMemRef:$windowDims, + AnyStridedMemRef:$output, + OptionalAttr:$strides, + OptionalAttr:$dilations, + OptionalAttr:$padding); + + let extraClassDeclaration = commonUtils# [{ + llvm::Optional> referenceIterators() { + // One outer parallel loop per output dimension. + unsigned nPar = getOutputShapedType(0).getRank(); + // One window loop per output dimension as well. + unsigned nWin = nPar; + SmallVector iters(nPar, getParallelIteratorTypeName()); + iters.reserve(nPar + nWin); + iters.append(nWin, getWindowIteratorTypeName()); + return iters; + } + + llvm::Optional> referenceIndexingMaps() { + MLIRContext *context = getContext(); + auto nPar = getNumParallelLoops(); + auto nWin = getNumWindowLoops(); + assert(nWin > 0 && "expected at least one window dimension"); + unsigned idx = 0; + auto outputDims = makeAffineDimExprs(nPar, idx, context); + auto windowDims = makeAffineDimExprs(nWin, idx, context); + // Construct the weighted input index expressions.
+ auto inputDims = + weightedPoolingInputIndex(*this, outputDims, windowDims); + return SmallVector{ + // input + AffineMap::get(idx, 0, inputDims), + // windowDims + AffineMap::get(idx, 0, windowDims), + // output + AffineMap::get(idx, 0, outputDims) + }; + } + }]; + + let verifier = [{ return ::verify(*this); }]; + + let hasFolder = 1; +} + +def PoolingMaxOp: SingleInputPoolingBase_Op<"pooling_max"> { + let description = [{ + Uses max as the pooling operation, i.e., it selects the maximum value in + the window. + }]; +} + +def PoolingMinOp: SingleInputPoolingBase_Op<"pooling_min"> { + let description = [{ + Uses min as the pooling operation, i.e., it selects the minimum value in + the window. + }]; +} + +def PoolingSumOp: SingleInputPoolingBase_Op<"pooling_sum"> { + let description = [{ + Uses addition as the pooling operation, i.e., it accumulates the values in + the window. + }]; +} + //===----------------------------------------------------------------------===// // Generic Linalg ops. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -72,6 +72,15 @@ /// function that implements the structured op. constexpr StringRef getLibraryCallAttrName() { return "library_call"; } +/// Attribute name for the i64 array attribute which encodes the strides. +constexpr StringRef getStridesAttrName() { return "strides"; } + +/// Attribute name for the i64 array attribute which encodes the dilations. +constexpr StringRef getDilationsAttrName() { return "dilations"; } + +/// Attribute name for the i64 attribute encoding (low, high) padding pairs. +constexpr StringRef getPaddingAttrName() { return "padding"; } + /// Use to encode that a particular iterator type has parallel semantics.
constexpr StringRef getParallelIteratorTypeName() { return "parallel"; } diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -524,12 +524,21 @@ MLIRContext *ctx) { // TODO(ntv) ConvOp conversion needs to export a descriptor with relevant // attribute values such as kernel striding and dilation. - patterns.insert, - LinalgOpConversion, LinalgOpConversion, - LinalgOpConversion, LinalgOpConversion, - LinalgOpConversion, - LinalgOpConversion, LinalgOpConversion>( - ctx); + // clang-format off + patterns.insert< + CopyTransposeConversion, + LinalgOpConversion, + LinalgOpConversion, + LinalgOpConversion, + LinalgOpConversion, + LinalgOpConversion, + LinalgOpConversion, + LinalgOpConversion, + LinalgOpConversion, + LinalgOpConversion, + LinalgOpConversion, + LinalgOpConversion>(ctx); + // clang-format on } } // namespace diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -140,7 +140,6 @@ p.printRegion(op.region()); p.printOptionalAttrDict(op.getAttrs(), attrNames); p << ": " << op.getOperandTypes(); - auto outputTensorTypes = op.getResultTypes(); if (!outputTensorTypes.empty()) p << " -> " << outputTensorTypes; @@ -827,8 +826,10 @@ return success(); } -static LogicalResult -verifyStrideOrDilation(ConvOp op, ArrayRef attrs, bool isStride) { +template +static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op, + ArrayRef attrs, + bool isStride) { auto strideOrDilation = isStride ? 
"stride" : "dilation"; if (attrs.size() != op.getNumWindowLoops()) return op.emitOpError("expects num ") @@ -860,6 +861,41 @@ return success(); } +template +LogicalResult verifySingleInputPoolingOp(PoolingOp op) { + auto inputType = op.input().getType().template cast(); + auto outputType = op.output().getType().template cast(); + if (outputType.getElementType() != inputType.getElementType()) + return op.emitOpError("expects memref elemental types to match"); + + auto windowDimsType = op.windowDims().getType().template cast(); + if (outputType.getRank() != inputType.getRank() || + outputType.getRank() != windowDimsType.getRank()) + return op.emitOpError("expects memref ranks to match"); + + if (auto strides = op.strides()) { + if (failed( + verifyStrideOrDilation(op, strides->getValue(), /*isStride=*/true))) + return failure(); + } + if (auto dilations = op.dilations()) { + if (failed(verifyStrideOrDilation(op, dilations->getValue(), + /*isStride=*/false))) + return failure(); + } + return success(); +} + +static LogicalResult verify(PoolingMaxOp op) { + return verifySingleInputPoolingOp(op); +} +static LogicalResult verify(PoolingMinOp op) { + return verifySingleInputPoolingOp(op); +} +static LogicalResult verify(PoolingSumOp op) { + return verifySingleInputPoolingOp(op); +} + namespace mlir { namespace linalg { @@ -894,21 +930,50 @@ return res; } -SmallVector -mlir::linalg::weightedConvInputIndex(ConvOp op, ArrayRef xs, - ArrayRef zs) { - assert(xs.size() == zs.size()); +template +static SmallVector +weightedPoolingInputIndexImpl(LinalgOp op, ArrayRef outputDims, + ArrayRef windowDims) { + assert(outputDims.size() == windowDims.size()); SmallVector res; - res.reserve(xs.size()); - for (unsigned i = 0, e = xs.size(); i < e; ++i) { + res.reserve(outputDims.size()); + for (unsigned i = 0, e = outputDims.size(); i < e; ++i) { // TODO(ntv): add a level of indirection to linalg.generic. 
- auto expr = - op.getStride(i) * xs[i] + op.getDilation(i) * zs[i] - op.getLowPad(i); + auto expr = op.getStride(i) * outputDims[i] + + op.getDilation(i) * windowDims[i] - op.getLowPad(i); res.push_back(expr); } return res; } +SmallVector +mlir::linalg::weightedPoolingInputIndex(ConvOp op, + ArrayRef outputDims, + ArrayRef windowDims) { + return weightedPoolingInputIndexImpl(op, outputDims, windowDims); +} + +SmallVector +mlir::linalg::weightedPoolingInputIndex(PoolingMaxOp op, + ArrayRef outputDims, + ArrayRef windowDims) { + return weightedPoolingInputIndexImpl(op, outputDims, windowDims); +} + +SmallVector +mlir::linalg::weightedPoolingInputIndex(PoolingMinOp op, + ArrayRef outputDims, + ArrayRef windowDims) { + return weightedPoolingInputIndexImpl(op, outputDims, windowDims); +} + +SmallVector +mlir::linalg::weightedPoolingInputIndex(PoolingSumOp op, + ArrayRef outputDims, + ArrayRef windowDims) { + return weightedPoolingInputIndexImpl(op, outputDims, windowDims); +} + SmallVector mlir::linalg::concat(ArrayRef a, ArrayRef b) { auto rangeA = llvm::make_range(a.begin(), a.end()); @@ -959,6 +1024,18 @@ SmallVectorImpl &) { return foldMemRefCast(*this); } +LogicalResult PoolingMaxOp::fold(ArrayRef, + SmallVectorImpl &) { + return foldMemRefCast(*this); +} +LogicalResult PoolingMinOp::fold(ArrayRef, + SmallVectorImpl &) { + return foldMemRefCast(*this); +} +LogicalResult PoolingSumOp::fold(ArrayRef, + SmallVectorImpl &) { + return foldMemRefCast(*this); +} LogicalResult CopyOp::fold(ArrayRef, SmallVectorImpl &) { return foldMemRefCast(*this); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -106,6 +106,23 @@ } } +// Returns a pair that contains input indices and output indices of a +// SingleInputPoolingOp `op`. 
+template +static std::pair, SmallVector> +getInputAndOutputIndices(ArrayRef allIvs, SingleInputPoolingOp op) { + auto &b = ScopedContext::getBuilder(); + auto loc = ScopedContext::getLocation(); + auto mapsRange = op.indexing_maps().template getAsRange(); + auto maps = + functional::map([](AffineMapAttr a) { return a.getValue(); }, mapsRange); + SmallVector iIdx( + makeCanonicalAffineApplies(b, loc, maps[0], allIvs)); + SmallVector oIdx( + makeCanonicalAffineApplies(b, loc, maps[2], allIvs)); + return {iIdx, oIdx}; +} + namespace { template class LinalgScopedEmitter {}; @@ -273,6 +290,57 @@ } }; +template +class LinalgScopedEmitter { +public: + static void emitScalarImplementation(ArrayRef allIvs, + PoolingMaxOp op) { + auto indices = getInputAndOutputIndices(allIvs, op); + ValueHandleArray iIdx(indices.first); + ValueHandleArray oIdx(indices.second); + + // Emit scalar form. + ValueHandle lhs = std_load(op.output(), oIdx); + ValueHandle rhs = std_load(op.input(), iIdx); + using edsc::op::operator>; + ValueHandle maxValue = std_select(lhs > rhs, lhs, rhs); + std_store(maxValue, op.output(), oIdx); + } +}; + +template +class LinalgScopedEmitter { +public: + static void emitScalarImplementation(ArrayRef allIvs, + PoolingMinOp op) { + auto indices = getInputAndOutputIndices(allIvs, op); + ValueHandleArray iIdx(indices.first); + ValueHandleArray oIdx(indices.second); + + // Emit scalar form. + ValueHandle lhs = std_load(op.output(), oIdx); + ValueHandle rhs = std_load(op.input(), iIdx); + using edsc::op::operator<; + ValueHandle minValue = std_select(lhs < rhs, lhs, rhs); + std_store(minValue, op.output(), oIdx); + } +}; + +template +class LinalgScopedEmitter { +public: + static void emitScalarImplementation(ArrayRef allIvs, + PoolingSumOp op) { + auto indices = getInputAndOutputIndices(allIvs, op); + SmallVector iIdx = indices.first; + SmallVector oIdx = indices.second; + IndexedValueType input(op.input()), output(op.output()); + + // Emit scalar form. 
+ output(oIdx) += input(iIdx); + } +}; + // Emits the MLIR for the scalar part of the generic op by: // 1. Emitting std_load and std_store ops for each input and output // view in order. This is achieved by applying the appropriate input or @@ -688,6 +756,9 @@ INSTANTIATE_LINALG_OP_TO_LOOPS(MatvecOp) INSTANTIATE_LINALG_OP_TO_LOOPS(MatmulOp) INSTANTIATE_LINALG_OP_TO_LOOPS(ConvOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingMaxOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingMinOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingSumOp) INSTANTIATE_LINALG_OP_TO_LOOPS(GenericOp) INSTANTIATE_LINALG_OP_TO_LOOPS(IndexedGenericOp) diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -513,3 +513,14 @@ %0 = linalg.reshape %arg0 [affine_map<(i, j, k) -> (i, j)>, affine_map<(i, j, k) -> (k)>] : memref into memref (d0 * s0 + d1)>> } + +// ----- + +func @pooling_rank_mismatch(%arg0: memref, + %arg1: memref<2x3xf32>, + %arg2: memref) { + // expected-error @+1 {{expects memref ranks to match}} + linalg.pooling_max(%arg0, %arg1, %arg2) {strides = [2, 1, 2]}: + memref, memref<2x3xf32>, memref + return +} diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -9,6 +9,7 @@ // CHECK-DAG: #[[strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)> // CHECK-DAG: #[[clampMinMap:.*]] = affine_map<(d0) -> (d0, 0)> +// CHECK-DAG: #[[Stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)> // CHECK-DAG: #[[Stride2Dilation1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)> // CHECK-DAG: #[[Stride2Dilation4:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)> // CHECK-DAG: #[[Stride3Dilation5:.*]] = affine_map<(d0, d1) -> (d0 * 3 + d1 * 5)> @@ -251,6 +252,75 @@ // CHECK: %{{.*}} = addf %{{.*}}, %{{.*}} : f32 // CHECK: store 
%{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref +func @pooling_max(%arg0: memref, + %arg1: memref, + %arg2: memref) { + linalg.pooling_max(%arg0, %arg1, %arg2) { strides = [2, 1] }: + memref, memref, memref + return +} +// CHECK-LABEL: func @pooling_max +// CHECK: %[[WX:.*]] = dim %arg1, 0 : memref +// CHECK: %[[WY:.*]] = dim %arg1, 1 : memref +// CHECK: %[[OX:.*]] = dim %arg2, 0 : memref +// CHECK: %[[OY:.*]] = dim %arg2, 1 : memref +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[OX]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[OY]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[WX]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[WY]] step %{{.*}} { +// CHECK: %[[IX:.*]] = affine.apply #[[Stride2Dilation1]](%{{.*}}, %{{.*}}) +// CHECK: %[[IY:.*]] = affine.apply #[[Stride1Dilation1]](%{{.*}}, %{{.*}}) +// CHECK: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref +// CHECK: %{{.*}} = load %{{.*}}[%[[IX]], %[[IY]]] : memref +// CHECK: %[[RES:.*]] = select %{{.*}}, %{{.*}}, %{{.*}} : f32 +// CHECK: store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref + +func @pooling_min(%arg0: memref, + %arg1: memref, + %arg2: memref) { + linalg.pooling_min(%arg0, %arg1, %arg2) { strides = [2, 1] }: + memref, memref, memref + return +} +// CHECK-LABEL: func @pooling_min +// CHECK: %[[WX:.*]] = dim %arg1, 0 : memref +// CHECK: %[[WY:.*]] = dim %arg1, 1 : memref +// CHECK: %[[OX:.*]] = dim %arg2, 0 : memref +// CHECK: %[[OY:.*]] = dim %arg2, 1 : memref +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[OX]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[OY]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[WX]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[WY]] step %{{.*}} { +// CHECK: %[[IX:.*]] = affine.apply #[[Stride2Dilation1]](%{{.*}}, %{{.*}}) +// CHECK: %[[IY:.*]] = affine.apply #[[Stride1Dilation1]](%{{.*}}, %{{.*}}) +// CHECK: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref +// CHECK: %{{.*}} = load 
%{{.*}}[%[[IX]], %[[IY]]] : memref +// CHECK: %[[RES:.*]] = select %{{.*}}, %{{.*}}, %{{.*}} : f32 +// CHECK: store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref + +func @pooling_sum(%arg0: memref, + %arg1: memref, + %arg2: memref) { + linalg.pooling_sum(%arg0, %arg1, %arg2) { strides = [2, 1] }: + memref, memref, memref + return +} +// CHECK-LABEL: func @pooling_sum +// CHECK: %[[WX:.*]] = dim %arg1, 0 : memref +// CHECK: %[[WY:.*]] = dim %arg1, 1 : memref +// CHECK: %[[OX:.*]] = dim %arg2, 0 : memref +// CHECK: %[[OY:.*]] = dim %arg2, 1 : memref +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[OX]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[OY]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[WX]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[WY]] step %{{.*}} { +// CHECK: %[[IX:.*]] = affine.apply #[[Stride2Dilation1]](%{{.*}}, %{{.*}}) +// CHECK: %[[IY:.*]] = affine.apply #[[Stride1Dilation1]](%{{.*}}, %{{.*}}) +// CHECK: %[[RHS:.*]] = load %{{.*}}[%[[IX]], %[[IY]]] : memref +// CHECK: %[[LHS:.*]] = load %{{.*}}[%{{.*}}, %{{.*}}] : memref +// CHECK: %[[RES:.*]] = addf %[[LHS]], %[[RHS]] : f32 +// CHECK: store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref + func @foo(%0: f32, %1: f32, %2: f32) -> (f32, f32) { %f0 = constant 0.0 : f32 return %f0, %f0 : f32, f32 diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -244,6 +244,48 @@ // ----- +func @pooling_max(%arg0: memref, + %arg1: memref, + %arg2: memref) { + linalg.pooling_max(%arg0, %arg1, %arg2) {strides = [2, 1, 2]}: + memref, memref, memref + return +} +// CHECK-LABEL: func @pooling_max +// CHECK: linalg.pooling_max(%{{.*}}, %{{.*}}, %{{.*}}) +// CHECK-SAME: {strides = [2, 1, 2]} +// CHECK-SAME: memref, memref, memref + +// ----- + +func @pooling_min(%arg0: memref, + %arg1: memref, + %arg2: memref) { + linalg.pooling_min(%arg0, %arg1, 
%arg2) {strides = [2, 1, 2]}: + memref, memref, memref + return +} +// CHECK-LABEL: func @pooling_min +// CHECK: linalg.pooling_min(%{{.*}}, %{{.*}}, %{{.*}}) +// CHECK-SAME: {strides = [2, 1, 2]} +// CHECK-SAME: memref, memref, memref + +// ----- + +func @pooling_sum(%arg0: memref, + %arg1: memref, + %arg2: memref) { + linalg.pooling_sum(%arg0, %arg1, %arg2) {strides = [2, 1, 2]}: + memref, memref, memref + return +} +// CHECK-LABEL: func @pooling_sum +// CHECK: linalg.pooling_sum(%{{.*}}, %{{.*}}, %{{.*}}) +// CHECK-SAME: {strides = [2, 1, 2]} +// CHECK-SAME: memref, memref, memref + +// ----- + // CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> // CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>