diff --git a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt @@ -2,7 +2,9 @@ set(LLVM_TARGET_DEFINITIONS LinalgStructuredOps.td) mlir_tablegen(LinalgStructuredOps.h.inc -gen-op-decls) mlir_tablegen(LinalgStructuredOps.cpp.inc -gen-op-defs) -mlir_tablegen(LinalgStructuredOpsInterfaces.h.inc -gen-op-interface-decls) -mlir_tablegen(LinalgStructuredOpsInterfaces.cpp.inc -gen-op-interface-defs) add_public_tablegen_target(MLIRLinalgStructuredOpsIncGen) +set(LLVM_TARGET_DEFINITIONS LinalgStructuredOpsInterface.td) +mlir_tablegen(LinalgStructuredOpsInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(LinalgStructuredOpsInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRLinalgStructuredOpsInterfaceIncGen) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h @@ -27,6 +27,8 @@ namespace mlir { namespace linalg { +class ConvOp; + /// Returns the name mangled library call name to disambiguate between different /// overloads at the C level. The name mangling scheme is basic and uses MLIR /// type names: @@ -52,23 +54,26 @@ /// name mangles into `linalg_matmul_viewxxf32_viewxxf32_viewxxf32_impl` std::string generateLibraryCallName(Operation *op); -/// Returns the list of maps that map loops to operands of a Linalg op. -/// The i-th affine map identifies loop indices to subscripts that are used when -/// accessing the i-th operand. 
-/// For instance, a matmul that can be written in index notation as: -/// `A(i, k) * B(k, j) -> C(i, j)` will have the following, ordered, list of -/// affine maps: -/// -/// ```mlir -/// ( -/// (i, j, k) -> (i, k), -/// (i, j, k) -> (k, j), -/// (i, j, k) -> (i, j) -/// ) -/// ``` -/// -/// Only permutation maps are currently supported. -SmallVector loopToOperandRangesMaps(Operation *op); +/// Returns `num` AffineDimExpr dimensions at positions +/// [startIdx, startIdx + num) and increments `startIdx` to `startIdx + num`. +SmallVector makeAffineDimExprs(unsigned num, unsigned &startIdx, + MLIRContext *context); + +/// Builds the indexing expressions for a ConvOp `op`. Returns the vector of +/// AffineExprs representing: +/// `stride[i] * xs[i] + dilation[i] * zs[i]` +SmallVector weightedConvInputIndex(ConvOp op, + ArrayRef xs, + ArrayRef zs); + +/// Returns `maybeMap.get()` if `maybeMap` is set, otherwise returns the +/// symbol-less identity map of `rank`. +AffineMap extractOrIdentityMap(Optional maybeMap, unsigned rank, + MLIRContext *context); + +/// Return the vector that is the concatenation of `a` and `b`. +SmallVector concat(ArrayRef a, + ArrayRef b); #include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.h.inc" diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -16,6 +16,7 @@ include "mlir/Dialect/AffineOps/AffineOpsBase.td" include "mlir/Dialect/Linalg/IR/LinalgBase.td" +include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td" // The Linalg `NInputs` trait provides the API for ops that are known // to have a specified number of inputs, all passed as operands. 
@@ -31,184 +32,6 @@ def StructuredOpTraits : NativeOpTrait<"linalg::StructuredOpTraits">; -// The linalg 'LinalgStructuredInterface' provides access to the 'LinalgOp' -// interface. -def LinalgStructuredInterface : OpInterface<"LinalgOp"> { - let methods = [ - //========================================================================// - // Loop types handling. - //========================================================================// - InterfaceMethod< - "Return the number of parallel loops within the current operation.", - "unsigned", "getNumParallelLoops" - >, - InterfaceMethod< - "Return the number of reduction loops within the current operation.", - "unsigned", "getNumReductionLoops" - >, - InterfaceMethod< - "Return the number of window loops within the current operation.", - "unsigned", "getNumWindowLoops" - >, - InterfaceMethod< - "Return the number of loops within the current operation.", - "unsigned", "getNumLoops">, - - InterfaceMethod< - [{Returns true if the current operation has only one loop and it's a - reduction loop}], - "unsigned", "hasSingleReductionLoop">, - - //========================================================================// - // Input arguments handling. - //========================================================================// - InterfaceMethod< - "Return the number of inputs from the current operation.", - "unsigned", "getNumInputs" - >, - InterfaceMethod<"Return the input view at the given index.", - "Value ", "getInput", (ins "unsigned":$i) - >, - InterfaceMethod<[{ - Return the index of the given input value `v`, or `None` if the value is - not an input. - }], - "llvm::Optional", "getIndexOfInput", (ins "Value ":$v) - >, - InterfaceMethod< - "Return the input operands from the current operation.", - "Operation::operand_range", "getInputs" - >, - InterfaceMethod<[{ - Return the type of the input shape at the given index. 
- }], "ShapedType", "getInputShapedType", (ins "unsigned":$i)>, - InterfaceMethod<[{ - Return the subset of input operands that are of ranked tensor type. - }], "SmallVector", "getInputTensorTypes">, - - //========================================================================// - // Output arguments handling. - //========================================================================// - InterfaceMethod< - "Return the number of outputs from the current operation.", - "unsigned", "getNumOutputs" - >, - InterfaceMethod<"Return the output buffer at the given index.", - "Value ", "getOutputBuffer", (ins "unsigned":$i) - >, - InterfaceMethod<[{ - Return the index of the given buffer value, or `None` if the value is - not part of the output buffers. - }], - "llvm::Optional", "getIndexOfOutputBuffer", (ins "Value ":$view) - >, - InterfaceMethod<[{ - Return the type of the output buffer at the given index. - }], "MemRefType", "getOutputBufferType", (ins "unsigned":$i)>, - InterfaceMethod<[{ - Return the results that are of ranked tensor type. - }], "SmallVector", "getOutputTensorTypes">, - InterfaceMethod< - "Return the output buffers (operands) from the current operation.", - "Operation::operand_range", "getOutputBuffers" - >, - - //========================================================================// - // Input and Output arguments handling. 
- //========================================================================// - InterfaceMethod< - "Return the number of inputs and outputs, irrespective of their buffer " - "or tensor type.", - "unsigned", "getNumInputsAndOutputs" - >, - InterfaceMethod< - "Return the number of inputs, irrespective of their buffer or tensor " - "type, and output buffers", - "unsigned", "getNumInputsAndOutputBuffers" - >, - InterfaceMethod< - "Return the range over inputs (irrespective of type) and output buffers.", - "Operation::operand_range", "getInputsAndOutputBuffers" - >, - - //========================================================================// - // Other interface methods. - //========================================================================// - InterfaceMethod< - "Return the reference iterators for this named op (if any are specied). " - "These reference iterators are used to specify the default behavior of " - "the op. Typically this would be a static method but in order to allow " - "rank-polymorphic ops, this needs to be per object instance. Named ops " - "must define referenceIterators, even if empty for the 0-D case. " - "Generic ops on the other hand have a None `referenceIterators`", - "llvm::Optional>", "referenceIterators" - >, - InterfaceMethod< - "Return the reference indexing maps for this named op (if any are " - "specified). Typically this would be a static method but in order to " - "allow rank-polymorphic ops, this needs to be per object instance. Named " - "ops must define referenceIterators, even if empty for the 0-D case. 
" - "Generic ops on the other hand have a None `referenceIndexingMaps`", - "llvm::Optional>", "referenceIndexingMaps" - >, - InterfaceMethod< - "Return the iterator types attribute within the current operation.", - "ArrayAttr", "iterator_types" - >, - InterfaceMethod< - "Return the indexing maps attribute within the current operation.", - "ArrayAttr", "indexing_maps" - >, - InterfaceMethod<"Return the input or output indexing map at index `i`.", - "AffineMap", "getIndexingMap", (ins "unsigned":$i) - >, - InterfaceMethod<"Return the input indexing map at index `i`.", - "AffineMap", "getInputIndexingMap", (ins "unsigned":$i) - >, - InterfaceMethod<"Return the output indexing map at index `i`.", - "AffineMap", "getOutputIndexingMap", (ins "unsigned":$i) - >, - InterfaceMethod<[{ - Return whether the op has only MemRef input and outputs. - }], "bool", "hasBufferSemantics">, - InterfaceMethod<[{ - Return whether the op has only RankedTensor input and outputs. - }], "bool", "hasTensorSemantics">, - - //========================================================================// - // Other static interface methods. - //========================================================================// - StaticInterfaceMethod<[{ - Create an operation of the current type with the given location, - operands, and attributes. - }], - "Operation *", "create", - (ins "OpBuilder &":$builder, "Location":$loc, - "ValueRange":$operands, - "ArrayRef":$attributes), [{ - return builder.create(loc, ArrayRef{}, operands, - attributes); - }] - >, - InterfaceMethod<[{ - Clone the current operation with the given location and operands. This - is used to abstract away the optional underlying region creation. 
- }], - "Operation *", "clone", - (ins "OpBuilder &":$b, "Location":$loc, "ValueRange":$operands), [{ - BlockAndValueMapping map; - unsigned numRegions = op.getOperation()->getNumRegions(); - Operation *res = create(b, loc, operands, op.getAttrs()); - assert(res->getNumRegions() == numRegions && "inconsistent # regions"); - for (unsigned ridx = 0; ridx < numRegions; ++ridx) - op.getOperation()->getRegion(ridx).cloneInto( - &res->getRegion(ridx), map); - return res; - }] - > - ]; -} - // Base Tablegen class for Linalg ops. // Linalg ops that correspond to library calls operate on linalg::View as their // first operands. These may be optionally followed by non-view operands @@ -228,9 +51,11 @@ let assemblyFormat = "`(` operands `)` attr-dict `:` type(operands)"; } -//////////////////////////////////////////////////////////////////////////////// -// Named Linalg ops, implemented as special configurations of a generic op. -//////////////////////////////////////////////////////////////////////////////// +//===----------------------------------------------------------------------===// +// Named Linalg ops, implemented as special configurations of generic ops. +//===----------------------------------------------------------------------===// +// At the moment these are not declarative and require a bunch of C++ code. +// In the future, these should be migrated to a declarative specification. def CopyOp : LinalgStructured_Op<"copy", [NInputs<1>, NOutputs<1>]> { let description = [{ Copies the data in the input view into the output view. 
@@ -245,8 +70,8 @@ ```mlir %0 = linalg.dim %arg0, 0 : index loop.for %i0 = %c0 to %0 step %c1 { - %1 = linalg.load %arg0[%i0] : memref - linalg.store %1, %arg1[%i0] : memref + %1 = load %arg0[%i0] : memref + store %1, %arg1[%i0] : memref } ``` @@ -269,20 +94,22 @@ loop.for %i0 = %c0 to %{{.*}} step %c1 { loop.for %i1 = %c0 to %{{.*}} step %c1 { loop.for %i2 = %c0 to %{{.*}} step %c1 { - %3 = linalg.load %arg0[%i0, %i2, %i1] : + %3 = load %arg0[%i0, %i2, %i1] : memref - linalg.store %3, %arg1[%i2, %i1, %i0] : + store %3, %arg1[%i2, %i1, %i0] : memref ``` The views are expected to be compatible for correctness but this is not enforced at the moment. }]; + let arguments = (ins AnyStridedMemRef:$input, AnyStridedMemRef:$output, OptionalAttr:$inputPermutation, OptionalAttr:$outputPermutation); + // TODO(ntv) this should go away once the usage of OptionalAttr triggers // emission of builders with default arguments left unspecified. let builders = [OpBuilder< @@ -290,11 +117,8 @@ return build( builder, result, input, output, AffineMapAttr(), AffineMapAttr()); }]>]; - let extraClassDeclaration = libraryCallName # [{ - // Defined in C++ for now. - // TODO(ntv): auto-generate. - ArrayAttr indexing_maps(); + let extraClassDeclaration = libraryCallName # [{ // Rank-polymorphic. // filling_value -> O(ivs) with parallel iterators. 
llvm::Optional> referenceIterators() { @@ -302,8 +126,16 @@ return SmallVector(nPar, getParallelIteratorTypeName()); } + // I(input_perm(ivs)) -> O(output_perm(ivs)) llvm::Optional> referenceIndexingMaps() { - llvm_unreachable("NYI referenceIndexingMaps for CopyOp"); + MLIRContext *context = getContext(); + auto maybeInputMap = inputPermutation(); + auto maybeOutputMap = outputPermutation(); + unsigned inputRank = getInputShapedType(0).getRank(); + unsigned outputRank = getOutputShapedType(0).getRank(); + return SmallVector{ + extractOrIdentityMap(maybeInputMap, inputRank, context), + extractOrIdentityMap(maybeOutputMap, outputRank, context)}; } }]; let verifier = [{ return ::verify(*this); }]; @@ -312,13 +144,10 @@ } def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> { + let arguments = (ins AnyStridedMemRef:$output, AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value); let extraClassDeclaration = libraryCallName # [{ - // Defined in C++ for now. - // TODO(ntv): auto-generate. - ArrayAttr indexing_maps(); - // Rank-polymorphic. // filling_value -> O(ivs) with parallel iterators. llvm::Optional> referenceIterators() { @@ -327,29 +156,35 @@ } llvm::Optional> referenceIndexingMaps() { - llvm_unreachable("NYI referenceIndexingMaps for CopyOp"); + MLIRContext *context = getContext(); + // filling_value -> O(ivs) + return SmallVector{ + extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)}; } }]; + let verifier = [{ return ::verify(*this); }]; - + let hasFolder = 1; } def DotOp : LinalgStructured_Op<"dot", [NInputs<2>, NOutputs<1>]> { + let arguments = (ins AnyStridedMemRefOfRank<1>, AnyStridedMemRefOfRank<1>, AnyStridedMemRefOfRank<0>); - let extraClassDeclaration = libraryCallName # [{ - // Defined in C++ for now. - // TODO(ntv): auto-generate. 
- ArrayAttr indexing_maps(); + let extraClassDeclaration = libraryCallName # [{ llvm::Optional> referenceIterators() { return SmallVector{getReductionIteratorTypeName()}; } + // A(r_i) * B(r_i) -> C() llvm::Optional> referenceIndexingMaps() { - llvm_unreachable("NYI referenceIndexingMaps for DotOp"); + MLIRContext *context = getContext(); + auto r_i = getAffineDimExpr(0, context); + return SmallVector{ + AffineMap::get(1, 0, {r_i}), AffineMap::get(1, 0, {r_i}), AffineMap()}; } }]; @@ -357,22 +192,25 @@ } def MatvecOp : LinalgStructured_Op<"matvec", [NInputs<2>, NOutputs<1>]> { + let arguments = (ins AnyStridedMemRefOfRank<2>, AnyStridedMemRefOfRank<1>, AnyStridedMemRefOfRank<1>); - let extraClassDeclaration = libraryCallName # [{ - // Defined in C++ for now. - // TODO(ntv): auto-generate. - ArrayAttr indexing_maps(); + let extraClassDeclaration = libraryCallName # [{ llvm::Optional> referenceIterators() { return SmallVector{ - getParallelIteratorTypeName(), - getReductionIteratorTypeName()}; + getParallelIteratorTypeName(), getReductionIteratorTypeName()}; } + // A(i, r_j) * B(r_j) -> C(i) llvm::Optional> referenceIndexingMaps() { - llvm_unreachable("NYI referenceIndexingMaps for MatvecOp"); + MLIRContext *context = getContext(); + AffineExpr i, r_j; + bindDims(context, i, r_j); + return SmallVector{ + AffineMap::get(2, 0, {i, r_j}), AffineMap::get(2, 0, {r_j}), + AffineMap::get(2, 0, {i})}; } }]; @@ -380,14 +218,12 @@ } def MatmulOp : LinalgStructured_Op<"matmul", [NInputs<2>, NOutputs<1>]> { + let arguments = (ins AnyStridedMemRefOfRank<2>, AnyStridedMemRefOfRank<2>, AnyStridedMemRefOfRank<2>); - let extraClassDeclaration = libraryCallName # [{ - // Defined in C++ for now. - // TODO(ntv): auto-generate. 
- ArrayAttr indexing_maps(); + let extraClassDeclaration = libraryCallName # [{ llvm::Optional> referenceIterators() { return SmallVector{ getParallelIteratorTypeName(), @@ -395,8 +231,14 @@ getReductionIteratorTypeName()}; } + // A(i, r_k) * B(r_k, j) -> C(i, j) llvm::Optional> referenceIndexingMaps() { - llvm_unreachable("NYI referenceIndexingMaps for MatmulOp"); + MLIRContext *context = getContext(); + AffineExpr i, j, r_k; + bindDims(context, i, j, r_k); + return SmallVector{AffineMap::get(3, 0, {i, r_k}), + AffineMap::get(3, 0, {r_k, j}), + AffineMap::get(3, 0, {i, j})}; } }]; @@ -404,6 +246,7 @@ } def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> { + let description = [{ Generic n-D convolution as described in the TF documentation: https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/nn/convolution @@ -427,16 +270,15 @@ AnyStridedMemRef:$output, OptionalAttr:$strides, OptionalAttr:$dilations); + let extraClassDeclaration = libraryCallName # [{ // TODO(ntv) extend to support more than 1 dimensions and potentially // grouping too. unsigned getNumBatchDimensions() { return 1; } + unsigned getNumInputFeatureDimensions() { return 1; } - unsigned getNumOutputFeatureDimensions() { return 1; } - // Defined in C++ for now. - // TODO(ntv): auto-generate. - ArrayAttr indexing_maps(); + unsigned getNumOutputFeatureDimensions() { return 1; } llvm::Optional> referenceIterators() { // Outer parallel loops are always the number of output dimensions; i.e. @@ -470,8 +312,43 @@ .cast().getValue().getSExtValue(); } + // F(z0, ..., zN-1, q, k) * I(b, x0 + z0, ..., xN-1 + zN-1, q) -> + // O(b, x0, ..., xN-1, k) + // for N equal to `nWindow`. 
llvm::Optional> referenceIndexingMaps() { - llvm_unreachable("NYI referenceIndexingMaps for MatmulOp"); + MLIRContext *context = getContext(); + auto nWin = getNumWindowLoops(); + assert(nWin > 0 && "expected at least one window dimension"); + unsigned idx = 0; + // In the following, AffineDimExprs are indexed in loop order: + // [ b, xs, k, q, zs] + // parallels non-window reductions windows + // + // Parallel dims are exactly the dimensions indexing `output`: + // output[b, x[0], ..., x[N-1], k]; i.e. + // * batch dimensions (bs with #bs = 1 for now) + // * "image" dimensions (xs with #xs = #zs = output_rank - #bs - #ks) + // * output filter dimensions (ks with #ks = 1 for now) + auto bs = makeAffineDimExprs(getNumBatchDimensions(), idx, context); + auto xs = makeAffineDimExprs(nWin, idx, context); + auto ks = makeAffineDimExprs( + getNumOutputFeatureDimensions(), idx, context); + // Non-window reduction dims: the `q` in sum_{z[0], ..., z[N-1], q} + auto qs = makeAffineDimExprs( + getNumInputFeatureDimensions(), idx, context); + // Window reduction dims: the `z[i]` in sum_{z[0], ..., z[N-1], q} + auto zs = makeAffineDimExprs(nWin, idx, context); + // Construct the weighted-sum expressions used to index the input. + auto ws = weightedConvInputIndex(*this, xs, zs); + return SmallVector{ + // filter[z[0], ..., z[N-1], q, k] + AffineMap::get(idx, 0, concat(concat(zs, qs), ks)), + // input[b, + // x[0]*s[0] + d[0]*z[0], ..., x[N-1]*s[N-1] + d[N-1]*z[N-1], + // q] + AffineMap::get(idx, 0, concat(concat(bs, ws), qs)), + // output[b, x[0], ..., x[N-1], k] + AffineMap::get(idx, 0, concat(concat(bs, xs), ks))}; } }]; @@ -480,6 +357,9 @@ let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// Generic Linalg ops. 
+//===----------------------------------------------------------------------===// def LinalgOperand: Type< Or<[AnyRankedTensor.predicate, AnyStridedMemRef.predicate]>>; @@ -489,9 +369,6 @@ CPred<"$_self.cast().getRank() == " # rank>] >>; -//////////////////////////////////////////////////////////////////////////////// -// Generic Linalg ops. -//////////////////////////////////////////////////////////////////////////////// class GenericOpBase : LinalgStructuredBase_Op { let arguments = (ins Variadic:$views, I64Attr:$args_in, @@ -622,12 +499,12 @@ loop.for %m = %c0 to %M step %c1 { loop.for %n = %c0 to %N step %c1 { loop.for %k = %c0 to %K step %c1 { - %a = linalg.load %A[%m, %k] : memref - %b = linalg.load %B[%k, %n] : memref - %c = linalg.load %C[%m, %n] : memref + %a = load %A[%m, %k] : memref + %b = load %B[%k, %n] : memref + %c = load %C[%m, %n] : memref %d = call @func_of_elements(%a, %b, %c) : (f32, f32, f32) -> (f32) - linalg.store %d, %C[%m, %n] : memref + store %d, %C[%m, %n] : memref } } } @@ -753,12 +630,12 @@ loop.for %m = %c0 to %M step %c1 { loop.for %n = %c0 to %N step %c1 { loop.for %k = %c0 to %K step %c1 { - %a = linalg.load %A[%m, %k] : memref - %b = linalg.load %B[%k, %n] : memref - %c = linalg.load %C[%m, %n] : memref + %a = load %A[%m, %k] : memref + %b = load %B[%k, %n] : memref + %c = load %C[%m, %n] : memref %d = call @func_of_elements_and_indices(%m, %n, %k, %a, %b, %c) : (index, index, index, f32, f32, f32) -> (f32) - linalg.store %d, %C[%m, %n] : memref + store %d, %C[%m, %n] : memref } } } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td @@ -0,0 +1,196 @@ +//===- LinalgStructuredInterface.td- Linalg StructuredIfce -*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the definition file for the structured interface for Linalg ops. +// +//===----------------------------------------------------------------------===// + +#ifndef LINALG_IR_STRUCTURED_OPS_INTERFACE +#define LINALG_IR_STRUCTURED_OPS_INTERFACE + +include "mlir/Dialect/Linalg/IR/LinalgBase.td" + +// The linalg 'LinalgStructuredInterface' provides access to the 'LinalgOp' +// interface. +def LinalgStructuredInterface : OpInterface<"LinalgOp"> { + let methods = [ + //===------------------------------------------------------------------===// + // Loop types handling. + //===------------------------------------------------------------------===// + InterfaceMethod< + "Return the number of parallel loops within the current operation.", + "unsigned", "getNumParallelLoops" + >, + InterfaceMethod< + "Return the number of reduction loops within the current operation.", + "unsigned", "getNumReductionLoops" + >, + InterfaceMethod< + "Return the number of window loops within the current operation.", + "unsigned", "getNumWindowLoops" + >, + InterfaceMethod< + "Return the number of loops within the current operation.", + "unsigned", "getNumLoops">, + + InterfaceMethod< + [{Returns true if the current operation has only one loop and it's a + reduction loop}], + "bool", "hasSingleReductionLoop">, + + //===------------------------------------------------------------------===// + // Input arguments handling. 
+ //===------------------------------------------------------------------===// + InterfaceMethod< + "Return the number of inputs from the current operation.", + "unsigned", "getNumInputs" + >, + InterfaceMethod<"Return the input view at the given index.", + "Value", "getInput", (ins "unsigned":$i) + >, + InterfaceMethod<[{ + Return the index of the given input value `v`, or `None` if the value is + not an input. + }], + "llvm::Optional", "getIndexOfInput", (ins "Value":$v) + >, + InterfaceMethod< + "Return the input operands from the current operation.", + "Operation::operand_range", "getInputs" + >, + InterfaceMethod<[{ + Return the type of the input shape at the given index. + }], "ShapedType", "getInputShapedType", (ins "unsigned":$i)>, + InterfaceMethod<[{ + Return the subset of input operands that are of ranked tensor type. + }], "SmallVector", "getInputTensorTypes">, + + //===------------------------------------------------------------------===// + // Output arguments handling. + //===------------------------------------------------------------------===// + InterfaceMethod< + "Return the number of outputs from the current operation.", + "unsigned", "getNumOutputs" + >, + InterfaceMethod<"Return the output buffer at the given index.", + "Value", "getOutputBuffer", (ins "unsigned":$i) + >, + InterfaceMethod<[{ + Return the index of the given buffer value, or `None` if the value is + not part of the output buffers. + }], + "llvm::Optional", "getIndexOfOutputBuffer", (ins "Value":$view) + >, + InterfaceMethod<[{ + Return the type of the output buffer at the given index. + }], "MemRefType", "getOutputBufferType", (ins "unsigned":$i)>, + InterfaceMethod<[{ + Return the results that are of ranked tensor type. 
+ }], "SmallVector", "getOutputTensorTypes">, + InterfaceMethod< + "Return the output buffers (operands) from the current operation.", + "Operation::operand_range", "getOutputBuffers" + >, + + //===------------------------------------------------------------------===// + // Input and Output arguments handling. + //===------------------------------------------------------------------===// + InterfaceMethod< + "Return the number of inputs and outputs, irrespective of their buffer " + "or tensor type.", + "unsigned", "getNumInputsAndOutputs" + >, + InterfaceMethod< + "Return the number of inputs, irrespective of their buffer or tensor " + "type, and output buffers", + "unsigned", "getNumInputsAndOutputBuffers" + >, + InterfaceMethod< + "Return the range over inputs (irrespective of type) and output buffers.", + "Operation::operand_range", "getInputsAndOutputBuffers" + >, + + //===------------------------------------------------------------------===// + // Other interface methods. + //===------------------------------------------------------------------===// + InterfaceMethod< + "Return the reference iterators for this named op (if any are specified). " + "These reference iterators are used to specify the default behavior of " + "the op. Typically this would be a static method but in order to allow " + "rank-polymorphic ops, this needs to be per object instance. Named ops " + "must define referenceIterators, even if empty for the 0-D case. " + "Generic ops on the other hand have a None `referenceIterators`", + "llvm::Optional>", "referenceIterators" + >, + InterfaceMethod< + "Return the reference indexing maps for this named op (if any are " + "specified). Typically this would be a static method but in order to " + "allow rank-polymorphic ops, this needs to be per object instance. Named " + "ops must define referenceIndexingMaps, even if empty for the 0-D case. 
" + "Generic ops on the other hand have a None `referenceIndexingMaps`", + "llvm::Optional>", "referenceIndexingMaps" + >, + InterfaceMethod< + "Return the iterator types attribute within the current operation.", + "ArrayAttr", "iterator_types" + >, + InterfaceMethod< + "Return the indexing maps attribute within the current operation.", + "ArrayAttr", "indexing_maps" + >, + InterfaceMethod<"Return the input or output indexing map at index `i`.", + "AffineMap", "getIndexingMap", (ins "unsigned":$i) + >, + InterfaceMethod<"Return the input indexing map at index `i`.", + "AffineMap", "getInputIndexingMap", (ins "unsigned":$i) + >, + InterfaceMethod<"Return the output indexing map at index `i`.", + "AffineMap", "getOutputIndexingMap", (ins "unsigned":$i) + >, + InterfaceMethod<[{ + Return whether the op has only MemRef input and outputs. + }], "bool", "hasBufferSemantics">, + InterfaceMethod<[{ + Return whether the op has only RankedTensor input and outputs. + }], "bool", "hasTensorSemantics">, + + //===------------------------------------------------------------------===// + // Other static interface methods. + //===------------------------------------------------------------------===// + StaticInterfaceMethod<[{ + Create an operation of the current type with the given location, + operands, and attributes. + }], + "Operation *", "create", + (ins "OpBuilder &":$builder, "Location":$loc, + "ValueRange":$operands, + "ArrayRef":$attributes), [{ + return builder.create(loc, ArrayRef{}, operands, + attributes); + }] + >, + InterfaceMethod<[{ + Clone the current operation with the given location and operands. This + is used to abstract away the optional underlying region creation. 
+ }], + "Operation *", "clone", + (ins "OpBuilder &":$b, "Location":$loc, "ValueRange":$operands), [{ + BlockAndValueMapping map; + unsigned numRegions = op.getOperation()->getNumRegions(); + Operation *res = create(b, loc, operands, op.getAttrs()); + assert(res->getNumRegions() == numRegions && "inconsistent # regions"); + for (unsigned ridx = 0; ridx < numRegions; ++ridx) + op.getOperation()->getRegion(ridx).cloneInto( + &res->getRegion(ridx), map); + return res; + }] + > + ]; +} + +#endif // LINALG_IR_STRUCTURED_OPS_INTERFACE diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h @@ -234,10 +234,11 @@ if (!maybeReferenceIteratorTypes && name != "generic" && name != "indexed_generic") { this->getOperation()->dump(); - llvm_unreachable("Op missing "); + llvm_unreachable("Op missing referenceIterators"); } - // If we have a reference, build the reference attribute. + // If we have a reference, build the reference attribute and set it in the + // op before returning. 
auto *ctx = this->getOperation()->getContext(); auto attrRange = llvm::map_range(*maybeReferenceIteratorTypes, [ctx](StringRef str) -> Attribute { diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt @@ -19,5 +19,6 @@ ${LIBS} MLIRLinalgOpsIncGen MLIRLinalgStructuredOpsIncGen + MLIRLinalgStructuredOpsInterfaceIncGen ) target_link_libraries(MLIRLinalgOps ${LIBS}) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -866,15 +866,6 @@ return success(); } -static AffineMap extractOrIdentityMap(Optional maybeMap, - unsigned rank, MLIRContext *context) { - if (maybeMap) - return maybeMap.getValue(); - if (rank == 0) - return AffineMap(); - return AffineMap::getMultiDimIdentityMap(rank, context); -} - namespace mlir { namespace linalg { @@ -889,123 +880,43 @@ } // namespace linalg } // namespace mlir -// Returns `num` AffineDimExpr dimensions at positions [curIdx, curIdx + num) -// and increments `curIdx` to `curIdx + num`. 
-static SmallVector -makeAffineDimExprs(unsigned num, unsigned &curIdx, MLIRContext *context) { - SmallVector res; - res.reserve(num); - for (unsigned i = 0; i < num; ++i) - res.push_back(getAffineDimExpr(curIdx++, context)); - return res; +AffineMap mlir::linalg::extractOrIdentityMap(Optional maybeMap, + unsigned rank, + MLIRContext *context) { + if (maybeMap) + return maybeMap.getValue(); + if (rank == 0) + return AffineMap(); + return AffineMap::getMultiDimIdentityMap(rank, context); } -static SmallVector -weightedConvInputIndex(ConvOp op, ArrayRef a, - ArrayRef b) { - assert(a.size() == b.size()); +SmallVector +mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx, + MLIRContext *context) { SmallVector res; - res.reserve(a.size()); - for (unsigned i = 0, e = a.size(); i < e; ++i) { - res.push_back(op.getStride(i) * a[i] + op.getDilation(i) * b[i]); - } + res.reserve(num); + for (unsigned i = 0; i < num; ++i) + res.push_back(getAffineDimExpr(startIdx++, context)); return res; } -static SmallVector concat(ArrayRef a, - ArrayRef b) { +SmallVector +mlir::linalg::weightedConvInputIndex(ConvOp op, ArrayRef xs, + ArrayRef zs) { + assert(xs.size() == zs.size()); SmallVector res; - res.reserve(a.size() + b.size()); - res.assign(a.begin(), a.end()); - res.append(b.begin(), b.end()); + res.reserve(xs.size()); + for (unsigned i = 0, e = xs.size(); i < e; ++i) + res.push_back(op.getStride(i) * xs[i] + op.getDilation(i) * zs[i]); return res; } -// Note: both functions below would completely disappear with a simple tensor -// kernel language. -// -// Ideally this should all be Tablegen'd but there is no good story for -// AffineMap for now. 
-SmallVector mlir::linalg::loopToOperandRangesMaps(Operation *op) { - MLIRContext *context = op->getContext(); - if (auto copyOp = dyn_cast(op)) { - // I(input_perm(ivs)) -> O(output_perm(ivs)) - auto maybeInputMap = copyOp.inputPermutation(); - auto maybeOutputMap = copyOp.outputPermutation(); - unsigned inputRank = copyOp.getInputShapedType(0).getRank(); - unsigned outputRank = copyOp.getOutputShapedType(0).getRank(); - return SmallVector{ - extractOrIdentityMap(maybeInputMap, inputRank, context), - extractOrIdentityMap(maybeOutputMap, outputRank, context)}; - } - if (auto fillOp = dyn_cast(op)) { - // filling_value -> O(ivs) - unsigned rank = fillOp.getNumParallelLoops(); - return SmallVector{ - extractOrIdentityMap(llvm::None, rank, context)}; - } - auto i = getAffineDimExpr(0, context); - auto j = getAffineDimExpr(1, context); - auto k = getAffineDimExpr(2, context); - if (isa(op)) - // A(r_i) * B(r_i) -> C() - return SmallVector{AffineMap::get(1, 0, {i}), - AffineMap::get(1, 0, {i}), AffineMap()}; - if (isa(op)) - // A(i, r_j) * B(r_j) -> C(i) - return SmallVector{AffineMap::get(2, 0, {i, j}), - AffineMap::get(2, 0, {j}), - AffineMap::get(2, 0, {i})}; - if (isa(op)) - // A(i, r_k) * B(r_k, j) -> C(i, j) - return SmallVector{AffineMap::get(3, 0, {i, k}), - AffineMap::get(3, 0, {k, j}), - AffineMap::get(3, 0, {i, j})}; - if (auto convOp = dyn_cast(op)) { - // F(z0, ..., zN-1, q, k) * I(b, x0 + z0, ..., xN-1 + zN-1, q) -> - // O(b, x0, ..., xN-1, k) - // for N equal to `nWindow`. - auto nWin = convOp.getNumWindowLoops(); - assert(nWin > 0 && "expected at least one window dimension"); - unsigned idx = 0; - // In the following, AffineDimExprs are indexed in loop order: - // [ b, xs, k, q, zs] - // parallels non-window reductions windows - // - // Parallel dims are exactly the dimensions indexing `output`: - // output[b, x[0], ..., x[N-1], k]; i.e. 
- // * batch dimensions (bs with #bs = 1 for now) - // * "image" dimensions (xs with #xs = #zs = output_rank - #bs - #ks) - // * output filter dimensions (ks with #ks = 1 for now) - auto bs = makeAffineDimExprs(convOp.getNumBatchDimensions(), idx, context); - auto xs = makeAffineDimExprs(nWin, idx, context); - auto ks = makeAffineDimExprs(convOp.getNumOutputFeatureDimensions(), idx, - context); - // Non-window reduction dim: sum_{z[0], ..., z[N-1], q} - auto qs = - makeAffineDimExprs(convOp.getNumInputFeatureDimensions(), idx, context); - // Window reduction dims: sum_{z[0], ..., z[N-1], q} - auto zs = makeAffineDimExprs(nWin, idx, context); - // Construct the weighedSum expression. - auto ws = weightedConvInputIndex(convOp, xs, zs); - return SmallVector{ - // filter[z[0], ..., z[N-1], q, k] - AffineMap::get(idx, 0, concat(concat(zs, qs), ks)), - // input[b, - // x[0]*s[0] + d[0]*z[0], ..., x[N-1]*s[N-1] + d[N-1]*z[N-1], - // q] - AffineMap::get(idx, 0, concat(concat(bs, ws), qs)), - // output[b, x[0], ..., x[N-1], k] - AffineMap::get(idx, 0, concat(concat(bs, xs), ks))}; - } - SmallVector res; - auto linalgOp = cast(op); - unsigned nViews = linalgOp.getNumInputsAndOutputs(); - res.reserve(nViews); - for (unsigned i = 0, e = nViews; i < e; ++i) - res.push_back(linalgOp.getIndexingMap(i)); - assert(nViews == linalgOp.indexing_maps().size()); - return res; +SmallVector mlir::linalg::concat(ArrayRef a, + ArrayRef b) { + auto rangeA = llvm::make_range(a.begin(), a.end()); + auto rangeB = llvm::make_range(b.begin(), b.end()); + auto concatRanges = llvm::concat(rangeA, rangeB); + return llvm::to_vector<4>(concatRanges); } static void appendMangledType(llvm::raw_string_ostream &ss, Type t) { @@ -1043,33 +954,6 @@ return ss.str(); } -static ArrayAttr getIndexingMaps(Operation *op) { - LinalgOp linalgOp = cast(op); - SmallVector maps; - maps.reserve(linalgOp.getNumInputsAndOutputs()); - for (AffineMap map : loopToOperandRangesMaps(op)) - 
maps.push_back(AffineMapAttr::get(map)); - return ArrayAttr::get(maps, op->getContext()); -} -ArrayAttr mlir::linalg::ConvOp::indexing_maps() { - return getIndexingMaps(getOperation()); -} -ArrayAttr mlir::linalg::CopyOp::indexing_maps() { - return getIndexingMaps(getOperation()); -} -ArrayAttr mlir::linalg::DotOp::indexing_maps() { - return getIndexingMaps(getOperation()); -} -ArrayAttr mlir::linalg::FillOp::indexing_maps() { - return getIndexingMaps(getOperation()); -} -ArrayAttr mlir::linalg::MatmulOp::indexing_maps() { - return getIndexingMaps(getOperation()); -} -ArrayAttr mlir::linalg::MatvecOp::indexing_maps() { - return getIndexingMaps(getOperation()); -} - // TODO(ntv, rriddle): Consider making all this boilerplate easy to autogenerate // with Tablegen. This seems a desirable property in the context of OpInterfaces // where a Linalg "named" op **isa** LinalgOp. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -70,7 +70,7 @@ static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, ArrayRef loopRanges) { assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); - auto maps = loopToOperandRangesMaps(op); + auto maps = op.indexing_maps(); SmallVector clonedViews; clonedViews.reserve(op.getNumInputsAndOutputs()); // Iterate over the inputs and outputs in order. @@ -78,7 +78,7 @@ SmallVector ios(op.getInputsAndOutputBuffers()); for (auto en : llvm::enumerate(ios)) { unsigned idx = en.index(); - auto map = maps[idx]; + auto map = maps[idx].cast().getValue(); LLVM_DEBUG(dbgs() << "map: " << map << "\n"); Value view = en.value(); SmallVector viewRanges(map.getNumResults()); @@ -122,13 +122,13 @@ // the first one. 
static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) { assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); - auto maps = loopToOperandRangesMaps(op); + auto maps = op.indexing_maps(); // Iterate over the inputs and outputs in order. // Extract the subranges from the linearized ranges. SmallVector ios(op.getInputsAndOutputBuffers()); for (auto en : llvm::enumerate(ios)) { unsigned idx = en.index(); - auto map = maps[idx]; + auto map = maps[idx].cast().getValue(); LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n"); LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n"); Value view = en.value(); @@ -164,7 +164,9 @@ // we can always identify a data dimension with a (at least one) loop // dimension. AffineMap producerMap = - loopToOperandRangesMaps(producer)[producer.getNumInputs() + producerIdx]; + producer.indexing_maps()[producer.getNumInputs() + producerIdx] + .cast() + .getValue(); LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx << ", producer map: " << producerMap << "\n"); @@ -191,11 +193,9 @@ << "existing LoopRange: " << loopRanges[i] << "\n"); else { auto viewDim = getViewDefiningLoopRange(producer, i); - loopRanges[i] = SubViewOp::Range{ - folded_std_constant_index(folder, 0), - std_dim(viewDim.view, viewDim.dimension), - folded_std_constant_index(folder, 1) - }; + loopRanges[i] = SubViewOp::Range{folded_std_constant_index(folder, 0), + std_dim(viewDim.view, viewDim.dimension), + folded_std_constant_index(folder, 1)}; LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n"); } } diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -20,6 +20,7 @@ #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Pass/Pass.h" +#include 
"mlir/Support/Functional.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/STLExtras.h" #include "mlir/Transforms/DialectConversion.h" @@ -179,7 +180,9 @@ "expected linalg op with buffer semantics"); auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); - auto maps = loopToOperandRangesMaps(convOp); + auto mapsRange = convOp.indexing_maps().getAsRange(); + auto maps = functional::map([](AffineMapAttr a) { return a.getValue(); }, + mapsRange); SmallVector fIdx( makeCanonicalAffineApplies(b, loc, maps[0], allIvs)); SmallVector imIdx( @@ -439,8 +442,11 @@ auto nLoops = nPar + nRed + nWin; if (!loweringIsAllowed(nPar, nLoops)) return failure(); - auto invertedMap = - inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp))); + auto mapsRange = + linalgOp.indexing_maps().template getAsRange(); + auto maps = + functional::map([](AffineMapAttr a) { return a.getValue(); }, mapsRange); + auto invertedMap = inversePermutation(concatAffineMaps(maps)); if (!invertedMap) { LinalgScopedEmitter::emitScalarImplementation( {}, linalgOp); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/Functional.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/STLExtras.h" #include "mlir/Transforms/FoldUtils.h" @@ -263,7 +264,8 @@ Value view = *(viewIteratorBegin + viewIndex); auto viewType = view.getType().cast(); unsigned rank = viewType.getRank(); - auto map = loopToOperandRangesMaps(linalgOp)[viewIndex]; + auto mapAttr = linalgOp.indexing_maps()[viewIndex]; + auto map = mapAttr.cast().getValue(); // If the view is not tiled, we can use it as is. 
if (!isTiled(map, tileSizes)) { res.push_back(view); @@ -355,8 +357,10 @@ auto viewSizes = getViewSizes(b, op); // The flattened loopToOperandRangesMaps is expected to be an invertible // permutation map (asserted in the inverse calculation). - auto viewSizesToLoopsMap = - inversePermutation(concatAffineMaps(loopToOperandRangesMaps(op))); + auto mapsRange = op.indexing_maps().getAsRange(); + auto maps = + functional::map([](AffineMapAttr a) { return a.getValue(); }, mapsRange); + auto viewSizesToLoopsMap = inversePermutation(concatAffineMaps(maps)); assert(viewSizesToLoopsMap && "expected invertible map"); SmallVector loopRanges;