diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h @@ -29,6 +29,7 @@ namespace linalg { class ConvOp; +class PoolingOp; /// Returns the name mangled library call name to disambiguate between different /// overloads at the C level. The name mangling scheme is basic and uses MLIR @@ -60,12 +61,15 @@ SmallVector makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context); -/// Builds the indexing expressions for a ConvOp `op`. Returns the vector of -/// AffineMaps representing: -/// `stride[i] * xs[i] + dilation[i] * zs[i] - pad_low[i]` -SmallVector weightedConvInputIndex(ConvOp op, - ArrayRef xs, - ArrayRef zs); +/// Builds the indexing expressions for a ConvOp/PoolingOp `op`. Returns the +/// vector of AffineMaps representing: +/// `stride[i] * outputDims[i] + dilation[i] * windowDims[i] - pad_low[i]` +SmallVector +weightedPoolingInputIndex(ConvOp op, ArrayRef outputDims, + ArrayRef windowDims); +SmallVector +weightedPoolingInputIndex(PoolingOp op, ArrayRef outputDims, + ArrayRef windowDims); /// Returns `maybeMap.get()` if `maybeMap` is set, otherwise returns the /// symbol-less identity map of `rank`. diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -247,7 +247,56 @@ let hasFolder = 1; } -def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> { +/// A base class for pooling operations such as conv. The arguments must contain +/// the optional arguments `strides`, `dilations` and `padding` with the following types: +/// OptionalAttr:$strides +/// OptionalAttr:$dilations +/// OptionalAttr:$padding +/// `strides` denotes the step of the window along each dimension. +class PoolingBase_Op props> + : LinalgStructuredBase_Op { + let description = [{ + The optional arguments are: + - strides: an i64 array specifying the stride (i.e. step) for window + loops. + - dilations: an i64 array specifying the filter upsampling/input + downsampling rate. + - padding: an i64 array of pairs (low, high) specifying the number of + elements to pad along a dimension. + + If the strides or dilations attributes are missing, then the default value is + one for each of the input dimensions. Similarly, padding values are zero + for both low and high in each of the dimensions, if not specified.
+ }]; + + code commonUtils = [{ + std::string getLibraryCallName() { + return generateLibraryCallName(getOperation()); + } + + int64_t getStride(unsigned i) { + assert(i < getNumWindowLoops()); + if (!strides().hasValue()) return 1; + return strides()->getValue()[i] + .cast().getValue().getSExtValue(); + } + + int64_t getDilation(unsigned i) { + assert(i < getNumWindowLoops()); + if (!dilations().hasValue()) return 1; + return dilations()->getValue()[i] + .cast().getValue().getSExtValue(); + } + + int64_t getLowPad(unsigned i) { + assert(i < getNumWindowLoops()); + if (!padding().hasValue()) return 0; + return padding().getValue().getValue({i, 0}); + } + }]; +} + +def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> { let description = [{ Generic n-D convolution as described in the TF documentation: @@ -278,7 +327,9 @@ OptionalAttr:$dilations, OptionalAttr:$padding); - let extraClassDeclaration = libraryCallName # [{ + let assemblyFormat = "`(` operands `)` attr-dict `:` type(operands)"; + + let extraClassDeclaration = commonUtils # [{ // TODO(ntv) extend to support more than 1 dimensions and potentially // grouping too. unsigned getNumBatchDimensions() { return 1; } @@ -305,26 +356,6 @@ return iters; } - int64_t getStride(unsigned i) { - assert(i < getNumWindowLoops()); - if (!strides().hasValue()) return 1; - return strides()->getValue()[i] - .cast().getValue().getSExtValue(); - } - - int64_t getDilation(unsigned i) { - assert(i < getNumWindowLoops()); - if (!dilations().hasValue()) return 1; - return dilations()->getValue()[i] - .cast().getValue().getSExtValue(); - } - - int64_t getLowPad(unsigned i) { - assert(i < getNumWindowLoops()); - if (!padding().hasValue()) return 0; - return padding().getValue().getValue({i, 0}); - } - // F(z0, ..., zN-1, q, k) * // I(b, x0 + z0 - pad_low_0, ..., xN-1 + zN-1 - pad_low_N-1, q) // -> O(b, x0, ..., xN-1, k) @@ -354,7 +385,7 @@ // Window reduction dims: sum_{z[0], ..., z[N-1], q} auto zs = makeAffineDimExprs(nWin, idx, context); // Construct the weighedSum expression. - auto ws = weightedConvInputIndex(*this, xs, zs); + auto ws = weightedPoolingInputIndex(*this, xs, zs); return SmallVector{ // filter[z[0], ..., z[N-1], q, k] AffineMap::get(idx, 0, concat(concat(zs, qs), ks)), @@ -374,6 +405,110 @@ let hasFolder = 1; } +def PoolingOp: PoolingBase_Op<"pooling", [NInputs<2>, NOutputs<1>]> { + let description = [{ + Applies a pooling function to all elements in each window of the input + multi-dimensional array, producing an output multi-dimensional array with + the same number of elements as the number of valid positions of the window. + + A linalg.pooling is written as: + ```mlir + linalg.pooling #trait_attribute %input, %windowDims, %output {other-attributes} : + memref, + memref, + memref + ``` + + Where #trait_attributes is an alias of a dictionary attribute described in + PoolingBase_Op. + + TODO(hanchung): Figure out a better way to handle window dimensions, i.e., + eliminate the fake memref. + The window dimensions are specified by argument `windowDims`. The i-th + dimension in the shape of `windowDims` denotes the size of the window along + dimension i. For example, if the window size is 2x3, then a memref<2x3> + should be passed to the operation as `windowDims`. + + + linalg.pooling takes a region (like generic op) as the pooling function. 
The + signature of the block must be: + + ``` + ^bb0([input view element type], [output view element type]) + -> ([output view element type]) + ``` + + Example: + + ```mlir + linalg.pooling { strides = [2, 1, 2] } %arg0, %arg1, %arg2 { + ^bb(%a: f32, %b: f32) : + %0 = addf %a, %b : f32 + linalg.yield %0 : f32 + }: memref, memref, memref + ``` + + }]; + + // TODO(hanchung): Consider making the pooling op take a function as an argument. + let arguments = (ins AnyStridedMemRef:$input, + AnyStridedMemRef:$windowDims, + AnyStridedMemRef:$output, + OptionalAttr:$strides, + OptionalAttr:$dilations, + OptionalAttr:$padding); + + let regions = (region AnyRegion:$region); + + let extraClassDeclaration = commonUtils # [{ + SmallVector linalgTraitAttrNames() { + return SmallVector{ + getStridesAttrName(), getDilationsAttrName(), getPaddingAttrName() + }; + } + + llvm::Optional> referenceIterators() { + // Outer parallel loops are always the number of output dimensions. + unsigned nPar = getOutputShapedType(0).getRank(); + // The number of window loops matches the number of output dimensions. + unsigned nWin = nPar; + SmallVector iters(nPar, getParallelIteratorTypeName()); + iters.reserve(nPar + nWin); + iters.append(nWin, getWindowIteratorTypeName()); + return iters; + } + + llvm::Optional> referenceIndexingMaps() { + MLIRContext *context = getContext(); + auto nPar = getNumParallelLoops(); + auto nWin = getNumWindowLoops(); + assert(nWin > 0 && "expected at least one window dimension"); + unsigned idx = 0; + auto outputDims = makeAffineDimExprs(nPar, idx, context); + auto windowDims = makeAffineDimExprs(nWin, idx, context); + // Construct the weightedSum expression. + auto inputDims = + weightedPoolingInputIndex(*this, outputDims, windowDims); + return SmallVector{ + // input + AffineMap::get(idx, 0, inputDims), + // windowDims + AffineMap::get(idx, 0, windowDims), + // output + AffineMap::get(idx, 0, outputDims) + }; + } + }]; + + let verifier = [{ return ::verify(*this); }]; + + let printer = [{ return ::print(p, *this); }]; + + let parser = [{ return ::parseTraitAndRegion(parser, result); }]; + + let hasFolder = 1; +} + //===----------------------------------------------------------------------===// // Generic Linalg ops. //===----------------------------------------------------------------------===// @@ -431,7 +566,7 @@ } }]; let printer = [{ return ::print(p, *this); }]; - let parser = [{ return ::parseGenericOp(parser, result); }]; + let parser = [{ return ::parseTraitAndRegion(parser, result); }]; } /// Index-free GenericOp. diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -72,6 +72,15 @@ /// function that implements the structured op. constexpr StringRef getLibraryCallAttrName() { return "library_call"; } +/// Attribute name for the I64ArrayAttr which encodes the value of strides. +constexpr StringRef getStridesAttrName() { return "strides"; } + +/// Attribute name for the I64ArrayAttr which encodes the value of dilations. +constexpr StringRef getDilationsAttrName() { return "dilations"; } + +/// Attribute name for the I64ElementsAttr which encodes the padding values. +constexpr StringRef getPaddingAttrName() { return "padding"; } + /// Use to encode that a particular iterator type has parallel semantics.
constexpr StringRef getParallelIteratorTypeName() { return "parallel"; } diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -525,8 +525,9 @@ // TODO(ntv) ConvOp conversion needs to export a descriptor with relevant // attribute values such as kernel striding and dilation. patterns.insert, - LinalgOpConversion, LinalgOpConversion, - LinalgOpConversion, LinalgOpConversion, + LinalgOpConversion, LinalgOpConversion, + LinalgOpConversion, LinalgOpConversion, + LinalgOpConversion, LinalgOpConversion, LinalgOpConversion, LinalgOpConversion>( ctx); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -124,8 +124,8 @@ // GenericOps //===----------------------------------------------------------------------===// -template -static void printGenericOp(OpAsmPrinter &p, GenericOpType op) { +template +static void printTraitAndRegion(OpAsmPrinter &p, OpType op) { auto attrNames = op.linalgTraitAttrNames(); llvm::StringSet<> linalgTraitAttrsSet; linalgTraitAttrsSet.insert(attrNames.begin(), attrNames.end()); @@ -140,7 +140,11 @@ p.printRegion(op.region()); p.printOptionalAttrDict(op.getAttrs(), attrNames); p << ": " << op.getOperandTypes(); +} +template +static void printGenericOp(OpAsmPrinter &p, GenericOpType op) { + printTraitAndRegion(p, op); auto outputTensorTypes = op.getResultTypes(); if (!outputTensorTypes.empty()) p << " -> " << outputTensorTypes; @@ -152,7 +156,10 @@ printGenericOp(p, op); } -static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) { +static void print(OpAsmPrinter &p, PoolingOp op) { printTraitAndRegion(p, op); } + +static ParseResult parseTraitAndRegion(OpAsmParser &parser, + OperationState &result) { SmallVector operandsInfo, regionOperandsInfo; DictionaryAttr dictAttr; // Parse the core linalg traits that must check into a dictAttr. @@ -779,6 +786,9 @@ if (indexedGenericOp) return verifyYield(op, indexedGenericOp); + if (auto poolingOp = dyn_cast(parentOp)) + return verifyYield(op, poolingOp); + return op.emitOpError("expected '") << GenericOp::getOperationName() << "' or '" << IndexedGenericOp::getOperationName() << "' parent op"; @@ -827,8 +837,10 @@ return success(); } -static LogicalResult -verifyStrideOrDilation(ConvOp op, ArrayRef attrs, bool isStride) { +template +static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op, + ArrayRef attrs, + bool isStride) { auto strideOrDilation = isStride ? 
"stride" : "dilation"; if (attrs.size() != op.getNumWindowLoops()) return op.emitOpError("expects num ") @@ -860,6 +872,31 @@ return success(); } +static LogicalResult verify(PoolingOp op) { + auto inputType = op.input().getType().cast(); + auto outputType = op.output().getType().cast(); + if (outputType.getElementType() != inputType.getElementType()) + return op.emitOpError("expects memref elemental types to match"); + + auto windowDimsType = op.windowDims().getType().cast(); + if (outputType.getRank() != inputType.getRank() || + outputType.getRank() != windowDimsType.getRank()) + return op.emitOpError("expects memref ranks to match"); + + if (auto strides = op.strides()) { + if (failed( + verifyStrideOrDilation(op, strides->getValue(), /*isStride=*/true))) + return failure(); + } + + if (op.region().empty()) + return op.emitOpError("expects a region"); + if (op.region().front().getNumArguments() != 2) + return op.emitOpError("expected number of block arguments to be 2"); + + return success(); +} + namespace mlir { namespace linalg { @@ -894,21 +931,36 @@ return res; } -SmallVector -mlir::linalg::weightedConvInputIndex(ConvOp op, ArrayRef xs, - ArrayRef zs) { - assert(xs.size() == zs.size()); +template +static SmallVector +weightedPoolingInputIndexImpl(LinalgOp op, ArrayRef outputDims, + ArrayRef windowDims) { + assert(outputDims.size() == windowDims.size()); SmallVector res; - res.reserve(xs.size()); - for (unsigned i = 0, e = xs.size(); i < e; ++i) { + res.reserve(outputDims.size()); + for (unsigned i = 0, e = outputDims.size(); i < e; ++i) { // TODO(ntv): add a level of indirection to linalg.generic. - auto expr = - op.getStride(i) * xs[i] + op.getDilation(i) * zs[i] - op.getLowPad(i); + auto expr = op.getStride(i) * outputDims[i] + + op.getDilation(i) * windowDims[i] - op.getLowPad(i); res.push_back(expr); } return res; } +SmallVector +mlir::linalg::weightedPoolingInputIndex(ConvOp op, + ArrayRef outputDims, + ArrayRef windowDims) { + return weightedPoolingInputIndexImpl(op, outputDims, windowDims); +} + +SmallVector +mlir::linalg::weightedPoolingInputIndex(PoolingOp op, + ArrayRef outputDims, + ArrayRef windowDims) { + return weightedPoolingInputIndexImpl(op, outputDims, windowDims); +} + SmallVector mlir::linalg::concat(ArrayRef a, ArrayRef b) { auto rangeA = llvm::make_range(a.begin(), a.end()); @@ -959,6 +1011,10 @@ SmallVectorImpl &) { return foldMemRefCast(*this); } +LogicalResult PoolingOp::fold(ArrayRef, + SmallVectorImpl &) { + return foldMemRefCast(*this); +} LogicalResult CopyOp::fold(ArrayRef, SmallVectorImpl &) { return foldMemRefCast(*this); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -267,6 +267,35 @@ } }; +template +class LinalgScopedEmitter { +public: + static void emitScalarImplementation(ArrayRef allIvs, + PoolingOp poolingOp) { + auto b = ScopedContext::getBuilder(); + auto loc = ScopedContext::getLocation(); + auto mapsRange = poolingOp.indexing_maps().getAsRange(); + auto maps = functional::map([](AffineMapAttr a) { return a.getValue(); }, + mapsRange); + SmallVector iIdx( + makeCanonicalAffineApplies(b, loc, maps[0], allIvs)); + SmallVector oIdx( + makeCanonicalAffineApplies(b, loc, maps[2], allIvs)); + + // 1. Emit std_load from operands. 
+ IndexedValueType input(poolingOp.input()), output(poolingOp.output()); + ValueHandle arg0 = input(iIdx); + ValueHandle arg1 = output(oIdx); + SmallVector indexedValues = {arg0.getValue(), arg1.getValue()}; + + // 2. Inline region, currently only works for a single basic block. + ValueHandleArray indexing = {ValueHandleArray(oIdx)}; + Value outputBuffer = poolingOp.output(); + inlineRegionAndEmitStdStore(poolingOp, indexedValues, indexing, + outputBuffer); + } +}; + // Emits the MLIR for the scalar part of the generic op by: // 1. Emitting std_load and std_store ops for each input and output // view in order. This is achieved by applying the appropriate input or @@ -680,6 +709,7 @@ INSTANTIATE_LINALG_OP_TO_LOOPS(MatvecOp) INSTANTIATE_LINALG_OP_TO_LOOPS(MatmulOp) INSTANTIATE_LINALG_OP_TO_LOOPS(ConvOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingOp) INSTANTIATE_LINALG_OP_TO_LOOPS(GenericOp) INSTANTIATE_LINALG_OP_TO_LOOPS(IndexedGenericOp) diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -513,3 +513,28 @@ %0 = linalg.reshape %arg0 [affine_map<(i, j, k) -> (i, j)>, affine_map<(i, j, k) -> (k)>] : memref into memref (d0 * s0 + d1)>> } + +// ----- + +func @pooling_rank_mismatch(%arg0: memref, + %arg1: memref<2x3xf32>, + %arg2: memref) { + // expected-error @+1 {{expects memref ranks to match}} + linalg.pooling { strides = [2, 1, 2] } %arg0, %arg1, %arg2 { + ^bb(%a: f32, %b: f32) : + %0 = addf %a, %b : f32 + linalg.yield %0 : f32 + }: memref, memref<2x3xf32>, memref + return +} + +// ----- + +func @pooling_missing_region(%arg0: memref, + %arg1: memref<2x3x2xf32>, + %arg2: memref) { + // expected-error @+1 {{expects a region}} + linalg.pooling { strides = [2, 1, 2] } %arg0, %arg1, %arg2 : + memref, memref<2x3x2xf32>, memref + return +} diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -9,6 +9,7 @@ // CHECK-DAG: #[[strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)> // CHECK-DAG: #[[clampMinMap:.*]] = affine_map<(d0) -> (d0, 0)> +// CHECK-DAG: #[[Stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)> // CHECK-DAG: #[[Stride2Dilation1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)> // CHECK-DAG: #[[Stride2Dilation4:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)> // CHECK-DAG: #[[Stride3Dilation5:.*]] = affine_map<(d0, d1) -> (d0 * 3 + d1 * 5)> @@ -251,6 +252,32 @@ // CHECK: %{{.*}} = addf %{{.*}}, %{{.*}} : f32 // CHECK: store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref +func @pooling_sum(%arg0: memref, + %arg1: memref, + %arg2: memref) { + linalg.pooling { strides = [2, 1] } %arg0, %arg1, %arg2 { + ^bb(%a: f32, %b: f32) : + %0 = addf %a, %b : f32 + linalg.yield %0 : f32 + }: memref, memref, memref + return +} +// CHECK-LABEL: func @pooling_sum +// CHECK: %[[WX:.*]] = dim %arg1, 0 : memref +// CHECK: %[[WY:.*]] = dim %arg1, 1 : memref +// CHECK: %[[OX:.*]] = dim %arg2, 0 : memref +// CHECK: %[[OY:.*]] = dim %arg2, 1 : memref +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[OX]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[OY]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[WX]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[WY]] step %{{.*}} { +// CHECK: %[[IX:.*]] = affine.apply #[[Stride2Dilation1]](%{{.*}}, %{{.*}}) +// CHECK: %[[IY:.*]] = 
affine.apply #[[Stride1Dilation1]](%{{.*}}, %{{.*}}) +// CHECK: %[[A:.*]] = load %{{.*}}[%[[IX]], %[[IY]]] : memref +// CHECK: %[[B:.*]] = load %{{.*}}[%{{.*}}, %{{.*}}] : memref +// CHECK: %[[RES:.*]] = addf %[[A]], %[[B]] : f32 +// CHECK: store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref + func @foo(%0: f32, %1: f32, %2: f32) -> (f32, f32) { %f0 = constant 0.0 : f32 return %f0, %f0 : f32, f32 diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -244,6 +244,26 @@ // ----- +func @pooling_sum(%arg0: memref, + %arg1: memref, + %arg2: memref) { + linalg.pooling { strides = [2, 1, 2] } %arg0, %arg1, %arg2 { + ^bb(%a: f32, %b: f32) : + %0 = addf %a, %b : f32 + linalg.yield %0 : f32 + }: memref, memref, memref + return +} +// CHECK-LABEL: func @pooling_sum +// CHECK: linalg.pooling {strides = [2, 1, 2]} +// CHECK-SAME: %{{.*}}, %{{.*}}, %{{.*}} { +// CHECK: ^{{.*}}(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): +// CHECK: %[[RES:.*]] = addf %[[LHS]], %[[RHS]] : f32 +// CHECK: linalg.yield %[[RES]] : f32 +// CHECK: }: memref, memref, memref + +// ----- + // CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> // CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
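
Editorial note (not part of the patch): for readers unfamiliar with the LinalgToLoops output, the snippet below sketches the loop nest generated for the 2-D `pooling_sum` example in loops.mlir above, assembled from its CHECK lines. The SSA names, the `%c0`/`%c1` bound constants, and the `memref<?x?xf32>` types are illustrative assumptions; the patch elides the exact types.

```mlir
// Window sizes come from the (fake) windowDims memref %arg1, output sizes
// from %arg2; with strides = [2, 1] the input index is (2*i0+i2, i1+i3).
%wx = dim %arg1, 0 : memref<?x?xf32>
%wy = dim %arg1, 1 : memref<?x?xf32>
%ox = dim %arg2, 0 : memref<?x?xf32>
%oy = dim %arg2, 1 : memref<?x?xf32>
loop.for %i0 = %c0 to %ox step %c1 {
  loop.for %i1 = %c0 to %oy step %c1 {
    loop.for %i2 = %c0 to %wx step %c1 {
      loop.for %i3 = %c0 to %wy step %c1 {
        %ix = affine.apply affine_map<(d0, d1) -> (d0 * 2 + d1)>(%i0, %i2)
        %iy = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i1, %i3)
        // Load the input element and the current accumulator, apply the
        // region body (addf), and store the result back to the output.
        %a = load %arg0[%ix, %iy] : memref<?x?xf32>
        %b = load %arg2[%i0, %i1] : memref<?x?xf32>
        %0 = addf %a, %b : f32
        store %0, %arg2[%i0, %i1] : memref<?x?xf32>
      }
    }
  }
}
```

Note that the window-dimensions memref only contributes its shape (via `dim`); its contents are never loaded, which is why the TODO in the op description calls it a fake memref.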