diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -485,7 +485,9 @@ AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types, OptionalAttr:$doc, - OptionalAttr:$library_call); + OptionalAttr:$library_call, + Confined, + [IntMinValue<0>]>:$symbol_source); let results = (outs Variadic:$output_tensors); let regions = (region AnyRegion:$region); let extraClassDeclaration = [{ @@ -493,7 +495,7 @@ return SmallVector{ getArgsInAttrName(), getArgsOutAttrName(), getDocAttrName(), getIndexingMapsAttrName(), getLibraryCallAttrName(), - getIteratorTypesAttrName() + getIteratorTypesAttrName(), getSymbolSourceAttrName() }; } @@ -508,12 +510,18 @@ llvm::Optional> referenceIterators() { llvm_unreachable( "No such thing as reference iterator types for a generic op."); - } + } llvm::Optional> referenceIndexingMaps() { llvm_unreachable( "No such thing as reference indexing maps for a generic op."); - } + } + + llvm::Optional getSymbolSource() { + auto ss = symbol_source(); + return ss.hasValue() ? + llvm::Optional(ss.getValue().getLimitedValue()) : llvm::None; + } }]; let printer = [{ return ::print(p, *this); }]; @@ -549,6 +557,10 @@ Each element of the list represents and iterator of one of the following types: parallel, reduction, window + - symbol_source: index of the operand whose dimensions will be propagated + as symbols to the indexing maps. When specified, the number of symbols + in each of the indexing maps has to be either 0 or the rank of the + specified operand. Example: Defining a #matmul_trait attribute in MLIR can be done as follows: @@ -629,6 +641,36 @@ escape naturally. Still, transformations and rewrites that take advantage of tensor SSA values are expected to be useful and will be added in the near future. 
+ + Example of 1D convolution with symbols: + ```mlir + #conv_1d_accesses = [ + affine_map<(m, n)[s0] -> (m + n - s0 floordiv 2)>, // in + affine_map<(m, n) -> (n)>, // filter + affine_map<(m, n) -> (m)> // out + ] + + #conv_1d_trait = { + doc = "O(m) += I(m + n - size(n) floordiv 2) * K(n)", + indexing_maps = #conv_1d_accesses, + library_call = "linalg_conv_1d", + iterator_types = ["parallel", "parallel"], + symbol_source = 1 + } + + linalg.generic #conv_1d_trait %in, %filter, %out { + ^bb0(%a: f32, %b: f32, %c: f32) : + %d = mulf %a, %b : f32 + %e = addf %c, %d : f32 + linalg.yield %e : f32 + } : memref, + memref, + memref + ``` + where symbol s0 will be substituted with `dim %filter, %c0`, i.e., the first + and only dimension of the second operand as specified by the symbol_source + attribute. Note that maps requiring no symbols do not need to declare + them. }]; let builders = [ diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -46,6 +46,10 @@ return indexingMaps == maps; } +/// Attribute name for the IntegerAttr which encodes the index of the operand +/// whose dimensions are propagated as symbols to the indexing maps. +constexpr StringRef getSymbolSourceAttrName() { return "symbol_source"; } + /// Attribute name for the AffineArrayAttr which encodes the relationship /// between a structured op iterators' and its operands. constexpr StringRef getIndexingMapsAttrName() { return "indexing_maps"; } diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h --- a/mlir/include/mlir/IR/AffineExpr.h +++ b/mlir/include/mlir/IR/AffineExpr.h @@ -118,6 +118,10 @@ AffineExpr replaceDimsAndSymbols(ArrayRef dimReplacements, ArrayRef symReplacements) const; + /// Replace symbols[0 .. numSymbols - 1] by + /// symbols[shift .. shift + numSymbols - 1]. 
+ AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift) const; + AffineExpr operator+(int64_t v) const; AffineExpr operator+(AffineExpr other) const; AffineExpr operator-() const; diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -69,7 +69,8 @@ builder.getAffineMapArrayAttr(maps), builder.getStrArrayAttr(iteratorStrTypes), StringAttr() /*doc*/, - StringAttr() /*library_call*/ + StringAttr() /*library_call*/, + IntegerAttr() /*symbol_source*/ /* TODO: other attributes in op */ ) .getOperation(); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -80,7 +80,8 @@ builder.getI64IntegerAttr(outputCount), builder.getAffineMapArrayAttr(indexingMaps), builder.getStrArrayAttr(iteratorTypes), - /*doc=*/nullptr, /*library_call=*/nullptr); + /*doc=*/nullptr, /*library_call=*/nullptr, + /*symbol_source=*/nullptr); if (!bodyBuild) return; @@ -105,7 +106,8 @@ builder.getI64IntegerAttr(outputCount), builder.getAffineMapArrayAttr(indexingMaps), builder.getStrArrayAttr(iteratorTypes), - /*doc=*/nullptr, /*library_call=*/nullptr); + /*doc=*/nullptr, /*library_call=*/nullptr, + /*symbol_source=*/nullptr); if (!bodyBuild) return; @@ -259,6 +261,17 @@ if (failed(BlockArgsVerifier::verify(op, region.front()))) return failure(); + // Maps may declare 0 symbols or one per dim of the symbol_source operand. + auto attr = op.template getAttrOfType<IntegerAttr>(getSymbolSourceAttrName()); + int64_t targetRank = 0; + if (attr) { + int64_t index = attr.getInt(); + if (index < 0 || index >= static_cast<int64_t>(op.getNumOperands())) + return op.emitOpError("symbol_source index out of range"); + targetRank = + op.getOperand(index).getType().template cast().getRank(); + } + SmallVector indexingMaps; indexingMaps.reserve(op.indexing_maps().size()); for (auto en : llvm::enumerate(op.indexing_maps())) { @@ -268,9 +281,9 @@ auto view = (idx < 
nInputViews) ? op.getInputShapedType(idx) : op.getOutputShapedType(idx - nInputViews); - if (m.getNumSymbols() != 0) - return op.emitOpError("expected indexing_map #") - << idx << " to have no symbols"; + if (m.getNumSymbols() != targetRank && m.getNumSymbols() != 0) + return op.emitOpError("expected the number of symbols in indexing_map #") + << idx << " to either match target rank or to be 0"; if (m.getNumDims() != nLoops) return op.emitOpError("expected indexing_map #") @@ -283,8 +296,8 @@ } auto concatMap = concatAffineMaps(indexingMaps); - auto aggregateMap = inversePermutation(concatMap); - if (!aggregateMap) + // TODO(limo): Support bound inference (invertibility) for maps with symbols. + if (!concatMap.getNumSymbols() && !inversePermutation(concatMap)) return op.emitOpError("expected the concatenation of maps in indexing_map " "to be invertible"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -319,7 +319,8 @@ genericOp.args_out(), rewriter.getAffineMapArrayAttr(newIndexingMaps), genericOp.iterator_types(), /*doc = */ nullptr, - /*library_call = */ nullptr); + /*library_call = */ nullptr, + /*symbol_source = */ nullptr); rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(), replacementOp.region().begin()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -510,7 +510,8 @@ rewriter.getArrayAttr(fusedIndexMaps), consumer.iterator_types(), /*doc=*/nullptr, - /*library_call=*/nullptr) + /*library_call=*/nullptr, + /*symbol_source=*/nullptr) .getOperation(); } else { fusedOp = @@ -524,7 +525,8 @@ rewriter.getArrayAttr(fusedIndexMaps), consumer.iterator_types(), /*doc=*/nullptr, - /*library_call=*/nullptr) + 
/*library_call=*/nullptr, + /*symbol_source=*/nullptr) .getOperation(); } @@ -787,7 +789,8 @@ rewriter.getI64IntegerAttr(consumer.getNumResults()), rewriter.getArrayAttr(indexMapAttrs), consumer.iterator_types(), /*doc=*/nullptr, - /*library_call=*/nullptr); + /*library_call=*/nullptr, + /*symbol_source=*/nullptr); auto &fusedRegion = fusedOp.region(); rewriter.cloneRegionBefore(consumer.region(), fusedRegion, fusedRegion.begin()); @@ -843,7 +846,8 @@ rewriter.getI64IntegerAttr(1), rewriter.getArrayAttr(indexMapAttrs), producer.iterator_types(), /*doc=*/nullptr, - /*library_call=*/nullptr); + /*library_call=*/nullptr, + /*symbol_source=*/nullptr); auto &fusedRegion = fusedOp.region(); rewriter.cloneRegionBefore(producer.region(), fusedRegion, fusedRegion.begin()); @@ -893,7 +897,8 @@ rewriter.getAffineMapArrayAttr(fusedIndexMaps), consumer.iterator_types(), /*doc=*/nullptr, - /*library_call=*/nullptr); + /*library_call=*/nullptr, + /*symbol_source=*/nullptr); // Map the block argument corresponding to the replaced argument with the // scalar constant. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -36,13 +36,14 @@ ArrayRef vals) { if (map.isEmpty()) return {}; - assert(map.getNumSymbols() == 0); + assert(map.getNumInputs() == vals.size() && + "expected one value per map dimension and symbol"); SmallVector res; res.reserve(map.getNumResults()); auto dims = map.getNumDims(); for (auto e : map.getResults()) { - auto exprMap = AffineMap::get(dims, 0, e); + auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e); SmallVector operands(vals.begin(), vals.end()); canonicalizeMapAndOperands(&exprMap, &operands); res.push_back(affine_apply(exprMap, operands)); @@ -169,8 +170,24 @@ // region has no uses. // 1.a. Emit load from input views.
+ // Dimensions of the `symbol_source` operand, used as the symbol operands + // of input indexing maps; materialized on first use and shared by inputs. + auto symbolSourceAttr = + linalgOp.template getAttrOfType<IntegerAttr>(getSymbolSourceAttrName()); + SmallVector<Value, 4> symbolValues; for (unsigned i = 0; i < nInputs; ++i) { + AffineMap map = linalgOp.getInputIndexingMap(i); + auto allIvsPlusDims = SmallVector(allIvs.begin(), allIvs.end()); + if (symbolSourceAttr && map.getNumSymbols() > 0) { + if (symbolValues.empty()) { + auto operand = linalgOp.getOperand(symbolSourceAttr.getInt()); + auto shapedType = operand.getType().template cast(); + for (unsigned idx = 0, e = shapedType.getRank(); idx < e; ++idx) + symbolValues.push_back(b.create<DimOp>(loc, operand, idx)); + } + allIvsPlusDims.append(symbolValues.begin(), symbolValues.end()); + } auto indexing = makeCanonicalAffineApplies( - b, loc, linalgOp.getInputIndexingMap(i), allIvs); + b, loc, map, allIvsPlusDims); // Passing through IndexedValueType emits the proper load operation. indexedValues.push_back(IndexedValueType(linalgOp.getInput(i))(indexing)); } @@ -457,7 +474,10 @@ linalgOp.indexing_maps().template getAsRange(); auto maps = llvm::to_vector<8>( llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); })); - AffineMap invertedMap = inversePermutation(concatAffineMaps(maps)); + AffineMap map = concatAffineMaps(maps); + // Symbols are not supported by bound inference yet; declare zero symbols. + AffineMap invertedMap = inversePermutation( + AffineMap::get(map.getNumDims(), 0, map.getResults(), map.getContext())); if (!invertedMap) return {}; if (invertedMap.isEmpty()) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp @@ -65,7 +65,8 @@ auto linalgOp = rewriter.create( loc, llvm::None, newArgs, rewriter.getI64IntegerAttr(operands.size()), 
rewriter.getI64IntegerAttr(results.size()), op.indexing_maps(), - op.iterator_types(), op.docAttr(), op.library_callAttr()); + op.iterator_types(), op.docAttr(), op.library_callAttr(), + op.symbol_sourceAttr()); // Create a new block in
the region of the new Generic Op. Block &oldBlock = op.getRegion().front(); diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -93,6 +93,16 @@ llvm_unreachable("Unknown AffineExpr"); } +/// Replace symbols[0 .. numSymbols - 1] by +/// symbols[shift .. shift + numSymbols - 1]. +AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift) const { + SmallVector symbols; + symbols.reserve(numSymbols); + for (unsigned idx = 0; idx < numSymbols; ++idx) + symbols.push_back(getAffineSymbolExpr(idx + shift, getContext())); + return replaceDimsAndSymbols({}, symbols); +} + /// Returns true if this expression is made out of only symbols and /// constants (no dimensional identifiers). bool AffineExpr::isSymbolicOrConstant() const { diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -413,18 +413,19 @@ } AffineMap mlir::concatAffineMaps(ArrayRef maps) { - unsigned numResults = 0; + unsigned numResults = 0, numDims = 0, numSymbols = 0; for (auto m : maps) numResults += m.getNumResults(); - unsigned numDims = 0; SmallVector results; results.reserve(numResults); for (auto m : maps) { - assert(m.getNumSymbols() == 0 && "expected map without symbols"); - results.append(m.getResults().begin(), m.getResults().end()); + for (auto res : m.getResults()) + results.push_back(res.shiftSymbols(m.getNumSymbols(), numSymbols)); + + numSymbols += m.getNumSymbols(); numDims = std::max(m.getNumDims(), numDims); } - return AffineMap::get(numDims, /*numSymbols=*/0, results, + return AffineMap::get(numDims, numSymbols, results, maps.front().getContext()); } diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -106,7 +106,7 @@ // ----- func @generic_symbol_in_map(%arg0: memref) { - // expected-error @+1 {{op expected indexing_map #0 to
have no symbols}} + // expected-error @+1 {{expected the number of symbols in indexing_map #0 to either match target rank or to be 0}} linalg.generic { args_in = 0, args_out = 1, @@ -120,6 +120,22 @@ // ----- +func @generic_symbol_source_out_of_range(%arg0: memref) { + // expected-error @+1 {{symbol_source index out of range}} + linalg.generic { + args_in = 0, + args_out = 1, + indexing_maps = [ affine_map<()[N] -> (0)> ], + iterator_types = ["parallel"], + symbol_source = 1 + } %arg0 { + ^bb(%i : i32): + linalg.yield %i : i32 + }: memref +} + +// ----- + func @generic_wrong_dim_in_map(%arg0: memref<1xi32>) { // expected-error @+1 {{op expected indexing_map #0 to have 1 dim(s) to match the number of loops}} linalg.generic { diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -14,6 +14,7 @@ // CHECKLOOP-DAG: #[[$stride2Dilation1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)> // CHECKLOOP-DAG: #[[$stride2Dilation4:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)> // CHECKLOOP-DAG: #[[$stride3Dilation5:.*]] = affine_map<(d0, d1) -> (d0 * 3 + d1 * 5)> +// CHECKLOOP-DAG: #[[$convMap:.*]] = affine_map<(d0, d1)[s0] -> (d0 + d1 - s0 floordiv 2)> // CHECKPARALLEL-DAG: #[[$strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> // CHECKPARALLEL-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> @@ -25,6 +26,7 @@ // CHECKPARALLEL-DAG: #[[$stride2Dilation1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)> // CHECKPARALLEL-DAG: #[[$stride2Dilation4:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)> // CHECKPARALLEL-DAG: #[[$stride3Dilation5:.*]] = affine_map<(d0, d1) -> (d0 * 3 + d1 * 5)> +// CHECKPARALLEL-DAG: #[[$convMap:.*]] = affine_map<(d0, d1)[s0] -> (d0 + d1 - s0 floordiv 2)> func @matmul(%arg0: memref, %M: index, %N: index, %K: index) { @@ -910,3 +912,331 @@ // CHECKPARALLEL: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32 // CHECKPARALLEL: 
%[[res:.*]] = addf %[[vc]], %[[inc]] : f32 // CHECKPARALLEL: store %[[res]], %[[mC]][%[[b]], %[[m]], %[[n]]] : memref + +#conv_1d_accesses = [ + affine_map<(m, n)[s0] -> (m + n - s0 floordiv 2)>, // in (s0 = dim 0 of %filter, operand 1) + affine_map<(m, n) -> (n)>, // filter + affine_map<(m, n) -> (m)> // out +] + +#conv_1d_trait = { + args_in = 2, + args_out = 1, + doc = "C(m) += A(m) * B(n)", + indexing_maps = #conv_1d_accesses, + library_call = "linalg_conv_1d", + n_views = [2, 1], + iterator_types = ["parallel", "parallel"], + symbol_source = 1 +} + +func @conv1d(%filter : memref, %in : memref, %out : memref) -> () { + linalg.generic #conv_1d_trait %in, %filter, %out { + ^bb0(%a: f32, %b: f32, %c: f32) : + %d = mulf %a, %b : f32 + %e = addf %c, %d : f32 + linalg.yield %e : f32 + } : memref, + memref, + memref + return +} + +// CHECKLOOP-LABEL: @conv1d +// CHECKLOOP-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref +// CHECKLOOP-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref +// CHECKLOOP-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref +// CHECKLOOP: %[[c0:.*]] = constant 0 : index +// CHECKLOOP: %[[dim0:.*]] = dim %[[arg0]], %[[c0]] : memref +// CHECKLOOP: %[[dim1:.*]] = dim %[[arg2]], %[[c0]] : memref +// CHECKLOOP: scf.for %[[b:.*]] = %{{.*}} to %[[dim1]] step %{{.*}} { +// CHECKLOOP: scf.for %[[m:.*]] = %{{.*}} to %[[dim0]] step %{{.*}} { +// CHECKLOOP: %[[dim2:.*]] = dim %[[arg0]], %[[c0]] : memref +// CHECKLOOP: %[[aff:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim2]]] +// CHECKLOOP: %[[va:.*]] = load %[[arg1]][%[[aff]]] : memref +// CHECKLOOP: %[[vb:.*]] = load %[[arg0]][%[[m]]] : memref +// CHECKLOOP: %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref +// CHECKLOOP: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32 +// CHECKLOOP: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 +// CHECKLOOP: store %[[res]], %[[arg2]][%[[b]]] : memref + +// CHECKPARALLEL-LABEL: @conv1d +// CHECKPARALLEL-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref +// CHECKPARALLEL-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref +// 
CHECKPARALLEL-SAME: %[[arg2:[a-zA-Z0-9]+]]: 
memref +// CHECKPARALLEL: %[[c0:.*]] = constant 0 : index +// CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg0]], %[[c0]] : memref +// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg2]], %[[c0]] : memref +// CHECKPARALLEL: scf.parallel (%[[b:.*]], %[[m:.*]]) = (%{{.*}}, %{{.*}}) to (%[[dim1]], %[[dim0]]) step ({{.*}}) { +// CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg0]], %[[c0]] : memref +// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim2]]] +// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[aff]]] : memref +// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[m]]] : memref +// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref +// CHECKPARALLEL: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32 +// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 +// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[b]]] : memref + +#conv_2d_accesses = [ + affine_map<(m, n, m1, n1)[s0, s1] -> (m + m1 - s0 floordiv 2, n + n1 - s1 floordiv 2)>, // in (s0, s1 = dims of %filter, operand 1) + affine_map<(m, n, m1, n1) -> (m1, n1)>, // filter + affine_map<(m, n, m1, n1) -> (m, n)> // out +] + +#conv_2d_trait = { + args_in = 2, + args_out = 1, + doc = "C(m,n) += A(m,n) * B(m1,n1)", + indexing_maps = #conv_2d_accesses, + library_call = "linalg_conv_2d", + n_views = [2, 1], + iterator_types = ["parallel", "parallel", "parallel", "parallel"], + symbol_source = 1 +} + +func @conv2d(%filter : memref, %in : memref, %out : memref) -> () { + linalg.generic #conv_2d_trait %in, %filter, %out { + ^bb0(%a: f32, %b: f32, %c: f32) : + %d = mulf %a, %b : f32 + %e = addf %c, %d : f32 + linalg.yield %e : f32 + } : memref, + memref, + memref + return +} + +// CHECKLOOP-LABEL: @conv2d +// CHECKLOOP-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref +// CHECKLOOP-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref +// CHECKLOOP-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref +// CHECKLOOP: %[[c0:.*]] = constant 0 : index +// CHECKLOOP: %[[c1:.*]] = constant 1 : index +// CHECKLOOP: %[[dim0:.*]] = dim %[[arg0]], %[[c0]] : memref +// CHECKLOOP: %[[dim1:.*]] = dim 
%[[arg0]], %[[c1]] : memref +// CHECKLOOP: %[[dim2:.*]] = dim %[[arg2]], %[[c0]] : memref +// CHECKLOOP: %[[dim3:.*]] = dim %[[arg2]], %[[c1]] : memref +// CHECKLOOP: scf.for %[[i0:.*]] = %{{.*}} to %[[dim2]] step %{{.*}} { +// CHECKLOOP: scf.for %[[i1:.*]] = %{{.*}} to %[[dim3]] step %{{.*}} { +// CHECKLOOP: scf.for %[[i2:.*]] = %{{.*}} to %[[dim0]] step %{{.*}} { +// CHECKLOOP: scf.for %[[i3:.*]] = %{{.*}} to %[[dim1]] step %{{.*}} { +// CHECKLOOP: %[[dim4:.*]] = dim %[[arg0]], %[[c0]] : memref +// CHECKLOOP: %[[dim5:.*]] = dim %[[arg0]], %[[c1]] : memref +// CHECKLOOP: %[[aff1:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim4]]] +// CHECKLOOP: %[[aff2:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim5]]] +// CHECKLOOP: %[[va:.*]] = load %[[arg1]][%[[aff1]], %[[aff2]]] : memref +// CHECKLOOP: %[[vb:.*]] = load %[[arg0]][%[[i2]], %[[i3]]] : memref +// CHECKLOOP: %[[vc:.*]] = load %[[arg2]][%[[i0]], %[[i1]]] : memref +// CHECKLOOP: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32 +// CHECKLOOP: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 +// CHECKLOOP: store %[[res]], %[[arg2]][%[[i0]], %[[i1]]] : memref + +// CHECKPARALLEL-LABEL: @conv2d +// CHECKPARALLEL-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref +// CHECKPARALLEL-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref +// CHECKPARALLEL-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref +// CHECKPARALLEL: %[[c0:.*]] = constant 0 : index +// CHECKPARALLEL: %[[c1:.*]] = constant 1 : index +// CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg0]], %[[c0]] : memref +// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg0]], %[[c1]] : memref +// CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg2]], %[[c0]] : memref +// CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c1]] : memref +// CHECKPARALLEL: scf.parallel (%[[i0:.*]], %[[i1:.*]], %[[i2:.*]], %[[i3:.*]]) = (%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) to (%[[dim2]], %[[dim3]], %[[dim0]], %[[dim1]]) step ({{.*}}) { +// CHECKPARALLEL: %[[dim4:.*]] = dim %[[arg0]], %[[c0]] : memref +// CHECKPARALLEL: %[[dim5:.*]] = dim %[[arg0]], 
%[[c1]] : memref +// CHECKPARALLEL: %[[aff1:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim4]]] +// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim5]]] +// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[aff1]], %[[aff2]]] : memref +// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[i2]], %[[i3]]] : memref +// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[i0]], %[[i1]]] : memref +// CHECKPARALLEL: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32 +// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 +// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[i0]], %[[i1]]] : memref + +#conv_3d_accesses = [ + affine_map<(m, n, k, m1, n1, k1)[s0, s1, s2] -> (m + m1 - s0 floordiv 2, n + n1 - s1 floordiv 2, k + k1 - s2 floordiv 2)>, // in (s0-s2 = dims of %filter, operand 1) + affine_map<(m, n, k, m1, n1, k1) -> (m1, n1, k1)>, // filter + affine_map<(m, n, k, m1, n1, k1) -> (m, n, k)> // out +] + +#conv_3d_trait = { + args_in = 2, + args_out = 1, + doc = "C(m,n,k) += A(m,n,k) * B(m1,n1,k1)", + indexing_maps = #conv_3d_accesses, + library_call = "linalg_conv_3d", + n_views = [2, 1], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"], + symbol_source = 1 +} + +func @conv3d(%filter : memref, %in : memref, %out : memref) -> () { + linalg.generic #conv_3d_trait %in, %filter, %out { + ^bb0(%a: f32, %b: f32, %c: f32) : + %d = mulf %a, %b : f32 + %e = addf %c, %d : f32 + linalg.yield %e : f32 + } : memref, + memref, + memref + return +} + +// CHECKLOOP-LABEL: @conv3d +// CHECKLOOP-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref +// CHECKLOOP-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref +// CHECKLOOP-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref +// CHECKLOOP: %[[c0:.*]] = constant 0 : index +// CHECKLOOP: %[[c1:.*]] = constant 1 : index +// CHECKLOOP: %[[c2:.*]] = constant 2 : index +// CHECKLOOP: %[[dim0:.*]] = dim %[[arg0]], %[[c0]] : memref +// CHECKLOOP: %[[dim1:.*]] = dim %[[arg0]], %[[c1]] : memref +// CHECKLOOP: %[[dim2:.*]] = dim 
%[[arg0]], %[[c2]] : memref +// 
CHECKLOOP: %[[dim3:.*]] = dim %[[arg2]], %[[c0]] : memref +// CHECKLOOP: %[[dim4:.*]] = dim %[[arg2]], %[[c1]] : memref +// CHECKLOOP: %[[dim5:.*]] = dim %[[arg2]], %[[c2]] : memref +// CHECKLOOP: scf.for %[[i0:.*]] = %{{.*}} to %[[dim3]] step %{{.*}} { +// CHECKLOOP: scf.for %[[i1:.*]] = %{{.*}} to %[[dim4]] step %{{.*}} { +// CHECKLOOP: scf.for %[[i2:.*]] = %{{.*}} to %[[dim5]] step %{{.*}} { +// CHECKLOOP: scf.for %[[i3:.*]] = %{{.*}} to %[[dim0]] step %{{.*}} { +// CHECKLOOP: scf.for %[[i4:.*]] = %{{.*}} to %[[dim1]] step %{{.*}} { +// CHECKLOOP: scf.for %[[i5:.*]] = %{{.*}} to %[[dim2]] step %{{.*}} { +// CHECKLOOP: %[[dim6:.*]] = dim %[[arg0]], %[[c0]] : memref +// CHECKLOOP: %[[dim7:.*]] = dim %[[arg0]], %[[c1]] : memref +// CHECKLOOP: %[[dim8:.*]] = dim %[[arg0]], %[[c2]] : memref +// CHECKLOOP: %[[aff1:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim6]]] +// CHECKLOOP: %[[aff2:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim7]]] +// CHECKLOOP: %[[aff3:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim8]]] +// CHECKLOOP: %[[va:.*]] = load %[[arg1]][%[[aff1]], %[[aff2]], %[[aff3]]] : memref +// CHECKLOOP: %[[vb:.*]] = load %[[arg0]][%[[i3]], %[[i4]], %[[i5]]] : memref +// CHECKLOOP: %[[vc:.*]] = load %[[arg2]][%[[i0]], %[[i1]], %[[i2]]] : memref +// CHECKLOOP: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32 +// CHECKLOOP: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 +// CHECKLOOP: store %[[res]], %[[arg2]][%[[i0]], %[[i1]], %[[i2]]] : memref + +// CHECKPARALLEL-LABEL: @conv3d +// CHECKPARALLEL-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref +// CHECKPARALLEL-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref +// CHECKPARALLEL-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref +// CHECKPARALLEL: %[[c0:.*]] = constant 0 : index +// CHECKPARALLEL: %[[c1:.*]] = constant 1 : index +// CHECKPARALLEL: %[[c2:.*]] = constant 2 : index +// CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg0]], %[[c0]] : memref +// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg0]], %[[c1]] : memref +// CHECKPARALLEL: 
%[[dim2:.*]] = dim %[[arg0]], %[[c2]] : memref +// CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c0]] : memref +// CHECKPARALLEL: %[[dim4:.*]] = dim %[[arg2]], %[[c1]] : memref +// CHECKPARALLEL: %[[dim5:.*]] = dim %[[arg2]], %[[c2]] : memref +// CHECKPARALLEL: scf.parallel (%[[i0:.*]], %[[i1:.*]], %[[i2:.*]], %[[i3:.*]], %[[i4:.*]], %[[i5:.*]]) = (%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) to (%[[dim3]], %[[dim4]], %[[dim5]], %[[dim0]], %[[dim1]], %[[dim2]]) step ({{.*}}) { +// CHECKPARALLEL: %[[dim6:.*]] = dim %[[arg0]], %[[c0]] : memref +// CHECKPARALLEL: %[[dim7:.*]] = dim %[[arg0]], %[[c1]] : memref +// CHECKPARALLEL: %[[dim8:.*]] = dim %[[arg0]], %[[c2]] : memref +// CHECKPARALLEL: %[[aff1:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim6]]] +// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim7]]] +// CHECKPARALLEL: %[[aff3:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim8]]] +// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[aff1]], %[[aff2]], %[[aff3]]] : memref +// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[i3]], %[[i4]], %[[i5]]] : memref +// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[i0]], %[[i1]], %[[i2]]] : memref +// CHECKPARALLEL: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32 +// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 +// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[i0]], %[[i1]], %[[i2]]] : memref + +#conv_4d_accesses = [ + affine_map<(m, n, k, l, m1, n1, k1, l1)[s0, s1, s2, s3] -> (m + m1 - s0 floordiv 2, n + n1 - s1 floordiv 2, k + k1 - s2 floordiv 2, l + l1 - s3 floordiv 2)>, // in (s0-s3 = dims of %filter, operand 1) + affine_map<(m, n, k, l, m1, n1, k1, l1) -> (m1, n1, k1, l1)>, // filter + affine_map<(m, n, k, l, m1, n1, k1, l1) -> (m, n, k, l)> // out +] + +#conv_4d_trait = { + args_in = 2, + args_out = 1, + doc = "C(m,n,k,l) += A(m,n,k,l) * B(m1,n1,k1,l1)", + indexing_maps = #conv_4d_accesses, + library_call = "linalg_conv_4d", + n_views = [2, 1], + iterator_types = 
["parallel", "parallel", 
"parallel", "parallel", "parallel", "parallel", "parallel", "parallel"], + symbol_source = 1 +} + +func @conv4d(%filter : memref, %in : memref, %out : memref) -> () { + linalg.generic #conv_4d_trait %in, %filter, %out { + ^bb0(%a: f32, %b: f32, %c: f32) : + %d = mulf %a, %b : f32 + %e = addf %c, %d : f32 + linalg.yield %e : f32 + } : memref, + memref, + memref + return +} + +// CHECKLOOP-LABEL: @conv4d +// CHECKLOOP-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref +// CHECKLOOP-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref +// CHECKLOOP-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref +// CHECKLOOP: %[[c0:.*]] = constant 0 : index +// CHECKLOOP: %[[c1:.*]] = constant 1 : index +// CHECKLOOP: %[[c2:.*]] = constant 2 : index +// CHECKLOOP: %[[c3:.*]] = constant 3 : index +// CHECKLOOP: %[[dim0:.*]] = dim %[[arg0]], %[[c0]] : memref +// CHECKLOOP: %[[dim1:.*]] = dim %[[arg0]], %[[c1]] : memref +// CHECKLOOP: %[[dim2:.*]] = dim %[[arg0]], %[[c2]] : memref +// CHECKLOOP: %[[dim3:.*]] = dim %[[arg0]], %[[c3]] : memref +// CHECKLOOP: %[[dim4:.*]] = dim %[[arg2]], %[[c0]] : memref +// CHECKLOOP: %[[dim5:.*]] = dim %[[arg2]], %[[c1]] : memref +// CHECKLOOP: %[[dim6:.*]] = dim %[[arg2]], %[[c2]] : memref +// CHECKLOOP: %[[dim7:.*]] = dim %[[arg2]], %[[c3]] : memref +// CHECKLOOP: scf.for %[[i0:.*]] = %{{.*}} to %[[dim4]] step %{{.*}} { +// CHECKLOOP: scf.for %[[i1:.*]] = %{{.*}} to %[[dim5]] step %{{.*}} { +// CHECKLOOP: scf.for %[[i2:.*]] = %{{.*}} to %[[dim6]] step %{{.*}} { +// CHECKLOOP: scf.for %[[i3:.*]] = %{{.*}} to %[[dim7]] step %{{.*}} { +// CHECKLOOP: scf.for %[[i4:.*]] = %{{.*}} to %[[dim0]] step %{{.*}} { +// CHECKLOOP: scf.for %[[i5:.*]] = %{{.*}} to %[[dim1]] step %{{.*}} { +// CHECKLOOP: scf.for %[[i6:.*]] = %{{.*}} to %[[dim2]] step %{{.*}} { +// CHECKLOOP: scf.for %[[i7:.*]] = %{{.*}} to %[[dim3]] step %{{.*}} { +// CHECKLOOP: %[[dim8:.*]] = dim %[[arg0]], %[[c0]] : memref +// CHECKLOOP: %[[dim9:.*]] = dim %[[arg0]], %[[c1]] : memref +// CHECKLOOP: %[[dim10:.*]] = dim %[[arg0]], 
%[[c2]] : 
memref +// CHECKLOOP: %[[dim11:.*]] = dim %[[arg0]], %[[c3]] : memref +// CHECKLOOP: %[[aff1:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim8]]] +// CHECKLOOP: %[[aff2:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim9]]] +// CHECKLOOP: %[[aff3:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim10]]] +// CHECKLOOP: %[[aff4:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim11]]] +// CHECKLOOP: %[[va:.*]] = load %[[arg1]][%[[aff1]], %[[aff2]], %[[aff3]], %[[aff4]]] : memref +// CHECKLOOP: %[[vb:.*]] = load %[[arg0]][%[[i4]], %[[i5]], %[[i6]], %[[i7]]] : memref +// CHECKLOOP: %[[vc:.*]] = load %[[arg2]][%[[i0]], %[[i1]], %[[i2]], %[[i3]]] : memref +// CHECKLOOP: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32 +// CHECKLOOP: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 +// CHECKLOOP: store %[[res]], %[[arg2]][%[[i0]], %[[i1]], %[[i2]], %[[i3]]] : memref + +// CHECKPARALLEL-LABEL: @conv4d +// CHECKPARALLEL-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref +// CHECKPARALLEL-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref +// CHECKPARALLEL-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref +// CHECKPARALLEL: %[[c0:.*]] = constant 0 : index +// CHECKPARALLEL: %[[c1:.*]] = constant 1 : index +// CHECKPARALLEL: %[[c2:.*]] = constant 2 : index +// CHECKPARALLEL: %[[c3:.*]] = constant 3 : index +// CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg0]], %[[c0]] : memref +// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg0]], %[[c1]] : memref +// CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg0]], %[[c2]] : memref +// CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg0]], %[[c3]] : memref +// CHECKPARALLEL: %[[dim4:.*]] = dim %[[arg2]], %[[c0]] : memref +// CHECKPARALLEL: %[[dim5:.*]] = dim %[[arg2]], %[[c1]] : memref +// CHECKPARALLEL: %[[dim6:.*]] = dim %[[arg2]], %[[c2]] : memref +// CHECKPARALLEL: %[[dim7:.*]] = dim %[[arg2]], %[[c3]] : memref +// CHECKPARALLEL: scf.parallel (%[[i0:.*]], %[[i1:.*]], %[[i2:.*]], %[[i3:.*]], %[[i4:.*]], %[[i5:.*]], %[[i6:.*]], %[[i7:.*]]) = (%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, 
%{{.*}}, %{{.*}}, %{{.*}}) to (%[[dim4]], %[[dim5]], %[[dim6]], %[[dim7]], %[[dim0]], %[[dim1]], %[[dim2]], %[[dim3]]) step ({{.*}}) { +// CHECKPARALLEL: %[[dim8:.*]] = dim %[[arg0]], %[[c0]] : memref +// CHECKPARALLEL: %[[dim9:.*]] = dim %[[arg0]], %[[c1]] : memref +// CHECKPARALLEL: %[[dim10:.*]] = dim %[[arg0]], %[[c2]] : memref +// CHECKPARALLEL: %[[dim11:.*]] = dim %[[arg0]], %[[c3]] : memref +// CHECKPARALLEL: %[[aff1:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim8]]] +// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim9]]] +// CHECKPARALLEL: %[[aff3:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim10]]] +// CHECKPARALLEL: %[[aff4:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim11]]] +// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[aff1]], %[[aff2]], %[[aff3]], %[[aff4]]] : memref +// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[i4]], %[[i5]], %[[i6]], %[[i7]]] : memref +// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[i0]], %[[i1]], %[[i2]], %[[i3]]] : memref +// CHECKPARALLEL: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32 +// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 +// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[i0]], %[[i1]], %[[i2]], %[[i3]]] : memref diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp --- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp +++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp @@ -77,7 +77,8 @@ auto linalgOp = rewriter.create( loc, llvm::None, newArgs, rewriter.getI64IntegerAttr(operands.size()), rewriter.getI64IntegerAttr(results.size()), op.indexing_maps(), - op.iterator_types(), op.docAttr(), op.library_callAttr()); + op.iterator_types(), op.docAttr(), op.library_callAttr(), + op.symbol_sourceAttr()); // Create a new block in the region of the new Generic Op. Block &oldBlock = op.getRegion().front();