diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -485,7 +485,9 @@ AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types, OptionalAttr:$doc, - OptionalAttr:$library_call); + OptionalAttr:$library_call, + Confined, + [IntMinValue<0>]>:$symbol_source); let results = (outs Variadic:$output_tensors); let regions = (region AnyRegion:$region); let extraClassDeclaration = [{ @@ -493,7 +495,7 @@ return SmallVector{ getArgsInAttrName(), getArgsOutAttrName(), getDocAttrName(), getIndexingMapsAttrName(), getLibraryCallAttrName(), - getIteratorTypesAttrName() + getIteratorTypesAttrName(), getSymbolSourceAttrName() }; } @@ -514,6 +516,12 @@ llvm_unreachable( "No such thing as reference indexing maps for a generic op."); } + + llvm::Optional getSymbolSource() { + auto ss = symbol_source(); + return ss.hasValue() ? + llvm::Optional(ss.getValue().getLimitedValue()) : llvm::None; + } }]; let printer = [{ return ::print(p, *this); }]; diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -46,6 +46,10 @@ return indexingMaps == maps; } +/// Attribute name for the IntegerAttr which encodes the index of the operand +/// whose dimensions will be propagated as symbols to the indexing maps. +constexpr StringRef getSymbolSourceAttrName() { return "symbol_source"; } + /// Attribute name for the AffineArrayAttr which encodes the relationship /// between a structured op iterators' and its operands. 
constexpr StringRef getIndexingMapsAttrName() { return "indexing_maps"; } diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -69,7 +69,8 @@ builder.getAffineMapArrayAttr(maps), builder.getStrArrayAttr(iteratorStrTypes), StringAttr() /*doc*/, - StringAttr() /*library_call*/ + StringAttr() /*library_call*/, + IntegerAttr() /*symbol_source*/ /* TODO: other attributes in op */ ) .getOperation(); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -80,7 +80,8 @@ builder.getI64IntegerAttr(outputCount), builder.getAffineMapArrayAttr(indexingMaps), builder.getStrArrayAttr(iteratorTypes), - /*doc=*/nullptr, /*library_call=*/nullptr); + /*doc=*/nullptr, /*library_call=*/nullptr, + /*symbol_source=*/nullptr); if (!bodyBuild) return; @@ -105,7 +106,8 @@ builder.getI64IntegerAttr(outputCount), builder.getAffineMapArrayAttr(indexingMaps), builder.getStrArrayAttr(iteratorTypes), - /*doc=*/nullptr, /*library_call=*/nullptr); + /*doc=*/nullptr, /*library_call=*/nullptr, + /*symbol_source=*/nullptr); if (!bodyBuild) return; @@ -259,13 +261,11 @@ if (failed(BlockArgsVerifier::verify(op, region.front()))) return failure(); - auto attr = op.getAttr("symbol_source"); + auto attr = op.template getAttrOfType("symbol_source"); int64_t targetRank = 0; if (attr) { - if (attr.getKind() != StandardAttributes::Kind::Integer) - return op.emitOpError("symbol_source attribute has to be integer"); - auto index = attr.template cast().getInt(); - if (index < 0 || index >= op.getNumOperands()) + uint64_t index = attr.getInt(); + if (index >= op.getNumOperands()) return op.emitOpError("symbol_source index out of range"); targetRank = op.getOperand(index).getType().template cast().getRank(); diff --git 
a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -319,7 +319,8 @@ genericOp.args_out(), rewriter.getAffineMapArrayAttr(newIndexingMaps), genericOp.iterator_types(), /*doc = */ nullptr, - /*library_call = */ nullptr); + /*library_call = */ nullptr, + /*symbol_source = */ nullptr); rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(), replacementOp.region().begin()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -510,7 +510,8 @@ rewriter.getArrayAttr(fusedIndexMaps), consumer.iterator_types(), /*doc=*/nullptr, - /*library_call=*/nullptr) + /*library_call=*/nullptr, + /*symbol_source=*/nullptr) .getOperation(); } else { fusedOp = @@ -524,7 +525,8 @@ rewriter.getArrayAttr(fusedIndexMaps), consumer.iterator_types(), /*doc=*/nullptr, - /*library_call=*/nullptr) + /*library_call=*/nullptr, + /*symbol_source=*/nullptr) .getOperation(); } @@ -787,7 +789,8 @@ rewriter.getI64IntegerAttr(consumer.getNumResults()), rewriter.getArrayAttr(indexMapAttrs), consumer.iterator_types(), /*doc=*/nullptr, - /*library_call=*/nullptr); + /*library_call=*/nullptr, + /*symbol_source=*/nullptr); auto &fusedRegion = fusedOp.region(); rewriter.cloneRegionBefore(consumer.region(), fusedRegion, fusedRegion.begin()); @@ -843,7 +846,8 @@ rewriter.getI64IntegerAttr(1), rewriter.getArrayAttr(indexMapAttrs), producer.iterator_types(), /*doc=*/nullptr, - /*library_call=*/nullptr); + /*library_call=*/nullptr, + /*symbol_source=*/nullptr); auto &fusedRegion = fusedOp.region(); rewriter.cloneRegionBefore(producer.region(), fusedRegion, fusedRegion.begin()); @@ -893,7 +897,8 @@ rewriter.getAffineMapArrayAttr(fusedIndexMaps), 
consumer.iterator_types(), /*doc=*/nullptr, - /*library_call=*/nullptr); + /*library_call=*/nullptr, + /*symbol_source=*/nullptr); // Map the block argument corresponding to the replaced argument with the // scalar constant. diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp @@ -65,7 +65,8 @@ auto linalgOp = rewriter.create( loc, llvm::None, newArgs, rewriter.getI64IntegerAttr(operands.size()), rewriter.getI64IntegerAttr(results.size()), op.indexing_maps(), - op.iterator_types(), op.docAttr(), op.library_callAttr()); + op.iterator_types(), op.docAttr(), op.library_callAttr(), + op.symbol_sourceAttr()); // Create a new block in the region of the new Generic Op. Block &oldBlock = op.getRegion().front(); diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -120,22 +120,6 @@ // ----- -func @generic_symbol_source_wrong_type(%arg0: memref) { - // expected-error @+1 {{symbol_source attribute has to be integer}} - linalg.generic { - args_in = 0, - args_out = 1, - indexing_maps = [ affine_map<()[N] -> (0)> ], - iterator_types = ["parallel"], - symbol_source = "none" - } %arg0 { - ^bb(%i : i32): - linalg.yield %i : i32 - }: memref -} - -// ----- - func @generic_symbol_source_out_of_range(%arg0: memref) { // expected-error @+1 {{symbol_source index out of range}} linalg.generic { @@ -152,22 +136,6 @@ // ----- -func @generic_symbol_source_out_of_range(%arg0: memref) { - // expected-error @+1 {{symbol_source index out of range}} - linalg.generic { - args_in = 0, - args_out = 1, - indexing_maps = [ affine_map<()[N] -> (0)> ], - iterator_types = ["parallel"], - symbol_source = -1 - } %arg0 { - ^bb(%i : i32): - linalg.yield %i : i32 - }: memref -} - -// 
----- - func @generic_wrong_dim_in_map(%arg0: memref<1xi32>) { // expected-error @+1 {{op expected indexing_map #0 to have 1 dim(s) to match the number of loops}} linalg.generic { diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp --- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp +++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp @@ -77,7 +77,8 @@ auto linalgOp = rewriter.create( loc, llvm::None, newArgs, rewriter.getI64IntegerAttr(operands.size()), rewriter.getI64IntegerAttr(results.size()), op.indexing_maps(), - op.iterator_types(), op.docAttr(), op.library_callAttr()); + op.iterator_types(), op.docAttr(), op.library_callAttr(), + op.symbol_sourceAttr()); // Create a new block in the region of the new Generic Op. Block &oldBlock = op.getRegion().front();