diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -26,8 +26,8 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
-/// Implementation of fusion of generic ops and indexed_generic ops.
-static bool areElementwiseOpsFusable(LinalgOp producer, LinalgOp consumer,
+/// Conditions for elementwise fusion of generic operations.
+static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
                                      unsigned consumerIdx) {
   // Producer and consumer must have tensor semantics.
   if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
@@ -95,57 +95,20 @@
 /// Generate the region of the fused tensor operation. The region of the fused
 /// op must be empty.
 static void
-generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
-                                 LinalgOp producer, LinalgOp consumer,
+generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
+                                 GenericOp producer, GenericOp consumer,
                                  AffineMap consumerToProducerLoopsMap,
                                  unsigned consumerIdx, unsigned nloops) {
   // Build the region of the fused op.
   Block &producerBlock = producer->getRegion(0).front();
   Block &consumerBlock = consumer->getRegion(0).front();
   Block *fusedBlock = new Block();
-  fusedOp->getRegion(0).push_back(fusedBlock);
+  fusedOp.region().push_back(fusedBlock);
   BlockAndValueMapping mapper;
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPointToStart(fusedBlock);
-  // The block arguments are
-  //   [index_0, index_1, ... ,
-  //    consumer_operand_0, ... , consumer_operand_(`consumerIdx`-1),
-  //    producer_operand_0, ... , producer_operand_(n-1)],
-  //    consumer_operand_(`consumerIdx`), .. consumer_operand_(m-1)]
-  // , where n is the number of producer's operand and m is the number
-  // consumer's operand.
-  // If both `numProducerIndices` and `numConsumerIndices` are zero, this is a
-  // generic op. In this case, there are no indices in block arguments.
-  unsigned numProducerIndices = isa<IndexedGenericOp>(producer.getOperation())
-                                    ? producer.getNumLoops()
-                                    : 0;
-  unsigned numConsumerIndices = isa<IndexedGenericOp>(consumer.getOperation())
-                                    ? consumer.getNumLoops()
-                                    : 0;
-  unsigned numFusedOpIndices =
-      (isa<IndexedGenericOp>(producer.getOperation()) ||
-       isa<IndexedGenericOp>(consumer.getOperation()))
-          ? std::max(producer.getNumLoops(), consumer.getNumLoops())
-          : 0;
-
-  // 0. Firstly, add all the indices to the block arguments.
-  for (unsigned i = 0, e = numFusedOpIndices; i < e; ++i)
-    fusedBlock->addArgument(rewriter.getIndexType());
-  // 1. Map consumer indices to fusedBlock indices 1-1.
-  mapper.map(consumerBlock.getArguments().take_front(numConsumerIndices),
-             fusedBlock->getArguments().take_front(numConsumerIndices));
-  // 2a. Embed producer indices into fusedBlock index space 1-1.
-  for (auto it :
-       llvm::zip(producerBlock.getArguments().take_front(numProducerIndices),
-                 fusedBlock->getArguments().take_front(numProducerIndices))) {
-    auto newIndex = rewriter.create<mlir::AffineApplyOp>(
-        producer.getLoc(),
-        consumerToProducerLoopsMap.getSubMap(std::get<0>(it).getArgNumber()),
-        fusedBlock->getArguments().take_front(numFusedOpIndices));
-    mapper.map(std::get<0>(it), newIndex);
-  }
-  // 2b. Add an index operation for every fused loop dimension and use the
+  // 2. Add an index operation for every fused loop dimension and use the
   //    `consumerToProducerLoopsMap` to map the producer indices.
   if (producer.hasIndexSemantics()) {
     // Add an index operation for every fused loop dimension.
@@ -169,34 +132,30 @@
   assert(consumerIdx < consumer.getNumInputs() &&
          "expected producer of input operand");
   // 3. Consumer input operands up to consumerIdx (exclusive).
-  for (BlockArgument bbArg : consumerBlock.getArguments()
-                                 .drop_front(numConsumerIndices)
-                                 .take_front(consumerIdx)) // input assumption.
+  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
+           consumerIdx)) // input assumption.
     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
 
   // Replacing consumerIdx requires getting the cloned, yielded, value from
   // the (cloned) producer block. This happens in step 9.
 
   // 4. Splice in producer's input operands.
-  for (BlockArgument bbArg : producerBlock.getArguments()
-                                 .drop_front(numProducerIndices)
-                                 .take_front(producer.getNumInputs()))
+  for (BlockArgument bbArg :
+       producerBlock.getArguments().take_front(producer.getNumInputs()))
     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
 
   // 4.b. Producer output operand/map that is fused needs to be mapped to the
   // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
   assert(producer->getNumResults() == 1 && "expected single result producer");
   if (producer.isInitTensor(&producer.getOutputOpOperands()[0])) {
-    BlockArgument bbArg =
-        producerBlock.getArguments()
-            .drop_front(numConsumerIndices + producer.getNumInputs())
-            // TODO: bbArg index of
-            .front();
+    BlockArgument bbArg = producerBlock.getArguments()
+                              .drop_front(producer.getNumInputs())
+                              // TODO: bbArg index of
+                              .front();
     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
   }
   // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
   for (BlockArgument bbArg : consumerBlock.getArguments()
-                                 .drop_front(numConsumerIndices)
                                  .take_front(consumer.getNumInputs())
                                  .drop_front(consumerIdx + 1))
     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
@@ -232,23 +191,21 @@
     assert(!producer->isAncestor(replacement.getDefiningOp()) &&
            "yielded value must have been mapped");
   }
-  mapper.map(consumerBlock.getArgument(consumerIdx + numConsumerIndices),
-             replacement);
+  mapper.map(consumerBlock.getArgument(consumerIdx), replacement);
   // 10. Clone operations from the consumer to the fused op.
   for (auto &op : consumerBlock.getOperations())
     rewriter.clone(op, mapper);
 
   // Sanity checks.
-  assert(fusedBlock->getNumArguments() ==
-             fusedOp->getNumOperands() + numFusedOpIndices &&
-         "Ill-formed LinalgOp region");
+  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
+         "Ill-formed GenericOp region");
 }
 
-static Optional<SmallVector<Value, 1>>
-fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
+static Optional<SmallVector<Value>>
+fuseElementwiseOpsImpl(GenericOp producer, OpOperand &consumerOpOperand,
                        const ControlElementwiseOpsFusionFn &controlFn,
                        PatternRewriter &rewriter) {
-  LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
+  auto consumer = cast<GenericOp>(consumerOpOperand.getOwner());
   unsigned consumerIdx = consumerOpOperand.getOperandNumber();
   if (!areElementwiseOpsFusable(producer, consumer, consumerIdx) ||
       !controlFn(producer->getResult(0), consumerOpOperand))
@@ -311,27 +268,14 @@
   assert(producer->getNumResults() == 1 && "expected single result producer");
 
   // Generate the fused op.
-  Operation *fusedOp;
-  if (isa<GenericOp>(producer.getOperation()) &&
-      isa<GenericOp>(consumer.getOperation())) {
-    fusedOp = rewriter.create<GenericOp>(
-        consumer.getLoc(), consumer->getResultTypes(),
-        /*inputs=*/fusedOperands,
-        // TODO: handle outputs.
-        consumer.getOutputs(), rewriter.getAffineMapArrayAttr(fusedIndexMaps),
-        consumer.iterator_types(),
-        /*doc=*/nullptr,
-        /*library_call=*/nullptr);
-  } else {
-    fusedOp = rewriter.create<IndexedGenericOp>(
-        consumer.getLoc(), consumer->getResultTypes(),
-        /*inputs=*/fusedOperands,
-        // TODO: handle outputs.
-        consumer.getOutputs(), rewriter.getAffineMapArrayAttr(fusedIndexMaps),
-        consumer.iterator_types(),
-        /*doc=*/nullptr,
-        /*library_call=*/nullptr);
-  }
+  auto fusedOp = rewriter.create<GenericOp>(
+      consumer.getLoc(), consumer->getResultTypes(),
+      /*inputs=*/fusedOperands,
+      // TODO: handle outputs.
+      consumer.getOutputs(), rewriter.getAffineMapArrayAttr(fusedIndexMaps),
+      consumer.iterator_types(),
+      /*doc=*/nullptr,
+      /*library_call=*/nullptr);
 
   // Construct an AffineMap from consumer loops to producer loops.
   // consumer loop -> tensor index
@@ -348,7 +292,7 @@
   generateFusedElementwiseOpRegion(rewriter, fusedOp, producer, consumer,
                                    consumerToProducerLoopsMap, consumerIdx,
                                    consumer.getNumLoops());
-  return SmallVector<Value, 1>(fusedOp->getResults());
+  return SmallVector<Value>(fusedOp->getResults());
 }
 
 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
@@ -373,7 +317,7 @@
 static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
                                         ArrayRef<int64_t> sourceShape,
                                         ArrayRef<AffineMap> reassociationMaps) {
-  SmallVector<AffineExpr, 4> resultExprs;
+  SmallVector<AffineExpr> resultExprs;
   resultExprs.reserve(reassociationMaps.size());
   ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
   MLIRContext *context = sourceMap.getContext();
@@ -386,8 +330,8 @@
     assert(!collapsedDims.empty());
     unsigned startDim =
         collapsedDims.front().cast<AffineDimExpr>().getPosition();
-    SmallVector<int64_t, 4> sizes;
-    SmallVector<AffineExpr, 4> dimExprs;
+    SmallVector<int64_t> sizes;
+    SmallVector<AffineExpr> dimExprs;
     for (auto en :
          llvm::zip(sourceShape.slice(startDim, collapsedDims.size()),
                    sourceExprs.slice(startDim, collapsedDims.size()))) {
@@ -426,22 +370,6 @@
   return useIndexMap.isPermutation();
 }
 
-/// Based on the type of `op` create a linalg op of the same type, i.e. if `op`
-/// is a linalg.generic operation, the create a `linalg.generic` operation with
-/// the given `args`. Expects `op` to be `linalg.generic` or
-/// `linalg.indexed_generic`.
-template <typename... Args>
-static LinalgOp createLinalgOpOfSameType(LinalgOp op, PatternRewriter &rewriter,
-                                         Args... args) {
-  if (isa<GenericOp>(op.getOperation()))
-    return rewriter.create<GenericOp>(args...);
-  if (isa<IndexedGenericOp>(op.getOperation()))
-    return rewriter.create<IndexedGenericOp>(args...);
-  llvm_unreachable(
-      "expected only linalg.generic or linalg.indexed_generic ops");
-  return nullptr;
-}
-
 /// Check if the reshape operation is only expansion into/collapsing of
 /// unit-dimension.
 static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
@@ -459,10 +387,10 @@
   return true;
 }
 
-/// Conditions for folding a generic/indexed-generic operation with a reshape op
-/// by expanding the iteration space dimensionality for tensor operations. These
-/// are preconditions assumed by `foldReshapeByDimExpansion` which implements
-/// the following fusion pattern.
+/// Conditions for folding a generic operation with a reshape op by expanding
+/// the iteration space dimensionality for tensor operations. These are
+/// preconditions assumed by `foldReshapeByDimExpansion` which implements the
+/// following fusion pattern.
 ///
 ///  Consider
 ///
 ///  %c = linalg.generic ins(%a, %b : memref<?x?xf32>, memref<?x?x?xf32>)
@@ -476,12 +404,12 @@
 ///         affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>]
 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
 ///
-/// The reshape can be folded into the `linalgOp` if the
-/// generic/indexed-generic op loop dimensionality is increased to match the
-/// result (operand) of the tensor_reshape when the reshape is expanding
-/// (folding). The indexing_map of the fused tensor in the `linalgOp` and the
-/// reassociation map helps compute the indexing maps of the modified op. For
-/// the above example, based on the reassociation map it can be concluded that
+/// The reshape can be folded into the `genericOp` if its loop dimensionality
+/// is increased to match the result (operand) of the tensor_reshape when the
+/// reshape is expanding (folding). The indexing_map of the fused tensor in the
+/// `genericOp` and the reassociation map helps compute the indexing maps of
+/// the modified op. For the above example, based on the reassociation map it
+/// can be concluded that
 ///
 /// - The loop used to access the first dimension of the fused tensor is split
 ///   into two.
@@ -520,41 +448,40 @@
 ///
 /// The added reshapes are again expanding patterns, so they will get fused
 /// with its producers if possible.
-static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
+static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
                                                unsigned fusedTensorIndex) {
   // Is fusable only if:
-  // - The linalgOp is a generic op, or an indexed_generic.
-  // - All the indexing maps for operands and results in linalgOp are projected
+  // - All the indexing maps for operands and results are projected
   //   permutations.
   // - The fused tensor is not a scalar.
-  // - All the loops in linalgOp are parallel loops.
-  return isa<GenericOp, IndexedGenericOp>(linalgOp.getOperation()) &&
-         linalgOp.hasTensorSemantics() &&
-         llvm::all_of(linalgOp.indexing_maps().getValue(),
+  // - All the loops are parallel loops.
+  return genericOp.hasTensorSemantics() &&
+         llvm::all_of(genericOp.indexing_maps().getValue(),
                       [](Attribute attr) {
                         return attr.cast<AffineMapAttr>()
                             .getValue()
                            .isProjectedPermutation();
                       }) &&
-         linalgOp.getIndexingMap(fusedTensorIndex).getNumResults() > 0 &&
-         llvm::all_of(linalgOp.iterator_types(), [](Attribute attr) {
+         genericOp.getIndexingMap(fusedTensorIndex).getNumResults() > 0 &&
+         llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
           return attr.cast<StringAttr>().getValue() ==
                  getParallelIteratorTypeName();
         });
 }
 
 namespace {
-/// Information needed to expand a generic/indexed_generic operation to fold the
-/// reshape with it.
+/// Information needed to expand a generic operation to fold the reshape with
+/// it.
 class ExpansionInfo {
 public:
   // Computes the mapping from original dimensions of the op to the dimensions
   // of the expanded op given the `indexingMap` of the fused operand/result of
-  // the generic/indexed_generic op, the `reassocationMaps` of the reshape op
-  // and the shape of the expanded op.
+  // the generic op, the `reassocationMaps` of the reshape op and the shape of
+  // the expanded op.
   LogicalResult compute(LinalgOp linalgOp, unsigned fusedTensorIndex,
                         ArrayRef<AffineMap> reassociationMaps,
-                        ArrayRef<int64_t> expandedShape);
+                        ArrayRef<int64_t> expandedShape,
+                        PatternRewriter &rewriter);
   unsigned getOrigOpNumDims() const { return reassociation.size(); }
   unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
   ReassociationIndicesRef getExpandedDims(unsigned i) const {
@@ -567,10 +494,10 @@
 private:
   /// Reassociation from the dimensions in the original operation to the
   /// dimension of the expanded operation.
-  SmallVector<ReassociationIndices, 4> reassociation;
+  SmallVector<ReassociationIndices> reassociation;
   /// Mapping from extent of loops in the original operation, to the extent of
   /// loops in the expanded operation.
-  SmallVector<SmallVector<int64_t>, 4> expandedShapeMap;
+  SmallVector<SmallVector<int64_t>> expandedShapeMap;
   unsigned expandedOpNumDims;
 };
 } // namespace
 
@@ -578,7 +505,8 @@
 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                      unsigned fusedTensorIndex,
                                      ArrayRef<AffineMap> reassociationMaps,
-                                     ArrayRef<int64_t> expandedShape) {
+                                     ArrayRef<int64_t> expandedShape,
+                                     PatternRewriter &rewriter) {
   if (reassociationMaps.empty())
     return failure();
   AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex);
@@ -586,13 +514,13 @@
   Optional<SmallVector<int64_t, 4>> originalLoopRange =
       linalgOp.getStaticLoopRanges();
   if (!originalLoopRange)
-    return linalgOp.emitError("unable to find loop range for operation");
+    return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");
 
   reassociation.clear();
   expandedShapeMap.clear();
   // Compute the number of dimension in the expanded op that correspond to each
   // dimension of the original op.
-  SmallVector<unsigned, 4> numExpandedDims(fusedIndexMap.getNumDims(), 1);
+  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
   expandedShapeMap.resize(fusedIndexMap.getNumDims());
   for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
     unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
@@ -627,17 +555,19 @@
 /// Note that this could be extended to handle dynamic case, but the
 /// implementation below uses `affine.apply` which seems to have issues when the
 /// shapes are not static.
-LogicalResult isIndexedOpExpandable(LinalgOp linalgOp,
-                                    const ExpansionInfo &expansionInfo) {
+LogicalResult isGenericOpExpandable(GenericOp genericOp,
+                                    const ExpansionInfo &expansionInfo,
+                                    PatternRewriter &rewriter) {
+  if (!genericOp.hasIndexSemantics())
+    return success();
   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
     ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
     if (expandedShape.size() == 1)
       continue;
     for (int64_t shape : expandedShape.drop_front()) {
       if (ShapedType::isDynamic(shape)) {
-        return linalgOp.emitError(
-            "unable to fuse indexed generic op where the expanded dim is "
-            "dynamic");
+        return rewriter.notifyMatchFailure(
+            genericOp, "cannot expand due to index semantics and dynamic dims");
       }
     }
   }
@@ -649,7 +579,7 @@
 static AffineMap
 getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                            const ExpansionInfo &expansionInfo) {
-  SmallVector<AffineExpr, 4> newExprs;
+  SmallVector<AffineExpr> newExprs;
   for (AffineExpr expr : indexingMap.getResults()) {
     unsigned pos = expr.cast<AffineDimExpr>().getPosition();
     SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
@@ -668,7 +598,7 @@
 static RankedTensorType getExpandedType(RankedTensorType originalType,
                                         AffineMap indexingMap,
                                         const ExpansionInfo &expansionInfo) {
-  SmallVector<int64_t, 4> expandedShape;
+  SmallVector<int64_t> expandedShape;
   for (AffineExpr expr : indexingMap.getResults()) {
     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
     auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
@@ -682,15 +612,15 @@
 /// the expanded operation. The same method is used to compute the
 /// `linalg.tensor_reshape` used to collapse the result of the expanded op to
 /// get the value that can replace all uses of the results of the original op.
-static SmallVector<ReassociationIndices, 4>
+static SmallVector<ReassociationIndices>
 getReassociationForExpansion(AffineMap indexingMap,
                              const ExpansionInfo &expansionInfo) {
-  SmallVector<ReassociationIndices, 4> reassociation;
+  SmallVector<ReassociationIndices> reassociation;
   unsigned numReshapeDims = 0;
   for (AffineExpr expr : indexingMap.getResults()) {
     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
     auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
-    auto indices = llvm::to_vector<2>(
+    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
         llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
     reassociation.emplace_back(std::move(indices));
     numReshapeDims += numExpandedDims;
@@ -698,66 +628,14 @@
   return reassociation;
 }
 
-/// Build the body of the expanded IndexedGenericOp. The arguments for the
-/// induction variables of the original operation need to be recovered by
-/// linearizing the arguments of the corresponding dimensions of the expanded
-/// op. For now it is assumed that the shapes of the expanded op needed for
-/// linearization are static.
-static void buildExpandedIndexedGenericOpRegion(
-    PatternRewriter &rewriter, Location loc, Region &originalOpRegion,
-    Region &fusedOpRegion, const ExpansionInfo &expansionInfo) {
-  assert(fusedOpRegion.empty() && "expected fused op to have empty region");
-  // Create an entry block in the fused region with same number of arguments
-  // as the fused op
-  Block *fusedEntryBlock = new Block;
-  fusedOpRegion.push_back(fusedEntryBlock);
-  rewriter.cloneRegionBefore(originalOpRegion, fusedOpRegion,
-                             fusedOpRegion.end());
-
-  // Merge the entry block of the fused op with the cloned blocks. For this
-  // compute the value for arguments of the region in the original operation
-  // in terms of the arguments of the fused op. Since the original operation
-  // is expanded, the expanded dimensions need to be folded back to get the
-  // replacement value for the arguments corresponding to interation index.
-  // For now this expects that all the loop ranges are constants, which is
-  // true if the shapes are all static. This has already been checked in the
-  // precondition.
-  using namespace edsc::op;
-  using namespace edsc::intrinsics;
-  OpBuilder::InsertionGuard guard(rewriter);
-  SmallVector<Value> argReplacements(originalOpRegion.getNumArguments());
-  rewriter.setInsertionPointToStart(fusedEntryBlock);
-  edsc::ScopedContext scopedContext(rewriter, loc);
-  IndexType indexType = rewriter.getIndexType();
-  for (auto i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
-    Value linearizedIndex = fusedEntryBlock->addArgument(indexType);
-    ArrayRef<int64_t> expandedDimsShape =
-        expansionInfo.getExpandedShapeOfDim(i).drop_front();
-    for (unsigned shape : expandedDimsShape) {
-      assert(!ShapedType::isDynamic(shape));
-      linearizedIndex = linearizedIndex * std_constant_index(shape);
-      linearizedIndex =
-          linearizedIndex + fusedEntryBlock->addArgument(indexType);
-    }
-    argReplacements[i] = linearizedIndex;
-  }
-  for (auto i : llvm::seq<unsigned>(expansionInfo.getOrigOpNumDims(),
-                                    argReplacements.size())) {
-    argReplacements[i] =
-        fusedEntryBlock->addArgument(originalOpRegion.getArgument(i).getType());
-  }
-  rewriter.mergeBlocks(fusedEntryBlock->getNextNode(), fusedEntryBlock,
-                       argReplacements);
-}
-
 /// Update the body of an expanded linalg operation having index semantics. The
 /// indices of the original operation need to be recovered by linearizing the
 /// indices of the correspoding dimensions of the expanded operation. For now it
 /// is assumed that the shapes of the expanded operation needed for
 /// linearization are static.
-static void updateExpandedIndexOpRegion(PatternRewriter &rewriter, Location loc,
-                                        Region &fusedRegion,
-                                        const ExpansionInfo &expansionInfo) {
+static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
+                                          Location loc, Region &fusedRegion,
+                                          const ExpansionInfo &expansionInfo) {
   // Replace the original indices by the linearization of the expanded indices.
   for (IndexOp indexOp :
        llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
@@ -793,112 +671,100 @@
   }
 }
 
-/// Implements the fusion of a tensor_reshape op and a generic/indexed_generic
-/// op as explained in `isFusableWithReshapeByExpansion`. Assumes that those
-/// conditions have been satisfied.
-static Optional<SmallVector<Value, 1>>
-fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
+/// Implements the fusion of a tensor_reshape op and a generic op as explained
+/// in `isFusableWithReshapeByExpansion`. Assumes that those conditions have
+/// been satisfied.
+static Optional<SmallVector<Value>>
+fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
                            unsigned fusedTensorIndex,
                            PatternRewriter &rewriter) {
-  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusedTensorIndex) &&
+  assert(isFusableWithReshapeByDimExpansion(genericOp, fusedTensorIndex) &&
         "preconditions for fuse operation failed");
   // Check if reshape is expanding or collapsing.
   bool isExpanding =
       reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank();
   RankedTensorType expandedType =
       isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType();
-  bool hasIndexSemantics = linalgOp.hasIndexSemantics() ||
-                           isa<IndexedGenericOp>(linalgOp.getOperation());
   ExpansionInfo expansionInfo;
-  if (failed(expansionInfo.compute(linalgOp, fusedTensorIndex,
+  if (failed(expansionInfo.compute(genericOp, fusedTensorIndex,
                                    reshapeOp.getReassociationMaps(),
-                                   expandedType.getShape())))
+                                   expandedType.getShape(), rewriter)))
     return llvm::None;
 
-  if (hasIndexSemantics &&
-      failed(isIndexedOpExpandable(linalgOp, expansionInfo)))
+  if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
    return llvm::None;
 
   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
-      llvm::map_range(linalgOp.getIndexingMaps(), [&](AffineMap m) {
+      llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));
 
-  SmallVector<Value, 4> expandedOpOperands;
-  for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
+  SmallVector<Value> expandedOpOperands;
+  for (auto operand : llvm::enumerate(genericOp.getInputs())) {
     if (operand.index() == fusedTensorIndex) {
       expandedOpOperands.push_back(reshapeOp.src());
       continue;
     }
-    AffineMap indexingMap = linalgOp.getInputIndexingMap(operand.index());
+    AffineMap indexingMap = genericOp.getInputIndexingMap(operand.index());
     RankedTensorType expandedOperandType =
         getExpandedType(operand.value().getType().cast<RankedTensorType>(),
                         indexingMap, expansionInfo);
     if (expandedOperandType != operand.value().getType()) {
       // Reshape the operand to get the right type.
-      SmallVector<ReassociationIndices, 4> reassociation =
+      SmallVector<ReassociationIndices> reassociation =
           getReassociationForExpansion(indexingMap, expansionInfo);
       expandedOpOperands.push_back(rewriter.create<TensorReshapeOp>(
-          linalgOp.getLoc(), expandedOperandType, operand.value(),
+          genericOp.getLoc(), expandedOperandType, operand.value(),
           reassociation));
       continue;
     }
     expandedOpOperands.push_back(operand.value());
   }
 
-  Location loc = linalgOp.getLoc();
-  SmallVector<Value, 1> outputs;
-  for (auto result : llvm::enumerate(linalgOp.getOutputs())) {
-    AffineMap indexingMap = linalgOp.getOutputIndexingMap(result.index());
+  Location loc = genericOp.getLoc();
+  SmallVector<Value> outputs;
+  for (auto result : llvm::enumerate(genericOp.getOutputs())) {
+    AffineMap indexingMap = genericOp.getOutputIndexingMap(result.index());
     RankedTensorType expandedOutputType =
         getExpandedType(result.value().getType().cast<RankedTensorType>(),
                         indexingMap, expansionInfo);
     if (expandedOutputType != result.value().getType()) {
-      SmallVector<ReassociationIndices, 4> reassociation =
+      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
       outputs.push_back(rewriter.create<TensorReshapeOp>(
-          linalgOp.getLoc(), expandedOutputType, result.value(),
+          genericOp.getLoc(), expandedOutputType, result.value(),
          reassociation));
     }
   }
 
   // The iterator types of the expanded op are all parallel.
-  SmallVector<StringRef, 4> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
-                                          getParallelIteratorTypeName());
+  SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
+                                       getParallelIteratorTypeName());
 
   TypeRange resultTypes = ValueRange(outputs).getTypes();
-  LinalgOp fusedOp = createLinalgOpOfSameType(
-      linalgOp, rewriter, linalgOp.getLoc(), resultTypes,
-      /*inputs=*/expandedOpOperands, outputs, expandedOpIndexingMaps,
-      iteratorTypes);
+  auto fusedOp =
+      rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
+                                 /*inputs=*/expandedOpOperands, outputs,
+                                 expandedOpIndexingMaps, iteratorTypes);
   Region &fusedRegion = fusedOp->getRegion(0);
-  Region &originalRegion = linalgOp->getRegion(0);
-
-  if (isa<GenericOp>(linalgOp.getOperation())) {
-    rewriter.cloneRegionBefore(originalRegion, fusedRegion,
-                               fusedRegion.begin());
-  } else {
-    assert(isa<IndexedGenericOp>(linalgOp.getOperation()));
-    buildExpandedIndexedGenericOpRegion(rewriter, loc, originalRegion,
-                                        fusedRegion, expansionInfo);
-  }
+  Region &originalRegion = genericOp->getRegion(0);
+  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
 
   // Update the index accesses after the expansion.
-  if (linalgOp.hasIndexSemantics())
-    updateExpandedIndexOpRegion(rewriter, loc, fusedRegion, expansionInfo);
+  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
 
   // Reshape the result values to their original shape if this is a collapsing
   // reshape folded into its consumer.
-  SmallVector<Value, 1> resultVals;
-  for (auto result : llvm::enumerate(linalgOp->getResults())) {
+  SmallVector<Value> resultVals;
+  for (auto result : llvm::enumerate(genericOp->getResults())) {
     if (!isExpanding &&
         resultTypes[result.index()] != result.value().getType()) {
-      SmallVector<ReassociationIndices, 4> reassociation =
+      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
-              linalgOp.getOutputIndexingMap(result.index()), expansionInfo);
+              genericOp.getOutputIndexingMap(result.index()), expansionInfo);
       resultVals.push_back(rewriter.create<TensorReshapeOp>(
-          linalgOp.getLoc(), result.value().getType(),
+          genericOp.getLoc(), result.value().getType(),
          fusedOp->getResult(result.index()), reassociation));
     } else {
       resultVals.push_back(fusedOp->getResult(result.index()));
@@ -934,22 +800,21 @@
 ///   %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
 ///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) ...
 ///        -> tensor<?x?x?xf32>
-template <typename LinalgOpTy, bool foldUnitDimReshapesOnly>
+template <bool foldUnitDimReshapesOnly>
 struct FoldProducerReshapeOpByLinearization
-    : public OpRewritePattern<LinalgOpTy> {
-  using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
+    : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(LinalgOpTy op,
+  LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    if (!op.hasTensorSemantics())
+    if (!genericOp.hasTensorSemantics())
       return failure();
-    LinalgOp linalgOp = cast<LinalgOp>(op.getOperation());
-    for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
+    for (auto operand : llvm::enumerate(genericOp.getInputs())) {
       TensorReshapeOp reshapeOp =
           operand.value().getDefiningOp<TensorReshapeOp>();
       if (!reshapeOp ||
          !isTensorReshapeOpFoldableByLinearization(
-              reshapeOp, linalgOp.getInputIndexingMap(operand.index()),
+              reshapeOp, genericOp.getInputIndexingMap(operand.index()),
               /*asProducer =*/true) ||
          (foldUnitDimReshapesOnly &&
           !isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
@@ -957,15 +822,15 @@
        continue;
 
       // Compute the fused operands list,
-      SmallVector<Value> fusedOperands(linalgOp.getInputs());
+      SmallVector<Value> fusedOperands(genericOp.getInputs());
      fusedOperands[operand.index()] = reshapeOp.src();
-      fusedOperands.append(linalgOp.getOutputs().begin(),
-                           linalgOp.getOutputs().end());
+      fusedOperands.append(genericOp.getOutputs().begin(),
+                           genericOp.getOutputs().end());
 
      // Compute indexing_maps for the fused operation. The indexing_maps for
      // the operands of the consumers that arent fused are the same.
      SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
-          op.indexing_maps().template getAsValueRange<AffineMapAttr, AffineMap>());
+          genericOp.indexing_maps().template getAsValueRange<AffineMapAttr, AffineMap>());
 
       // Accepted consumer maps are either identity or permutation.
       auto invMap = inversePermutation(fusedIndexMaps[operand.index()]);
@@ -984,13 +849,14 @@
       // inverted. Without this the resultant op is not legal.
       if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
-            op, "fused op loop bound computation failed");
+            genericOp, "fused op loop bound computation failed");
       }
 
-      rewriter.startRootUpdate(op);
-      op->setOperands(fusedOperands);
-      op.indexing_mapsAttr(rewriter.getAffineMapArrayAttr(fusedIndexMaps));
-      rewriter.finalizeRootUpdate(op);
+      rewriter.startRootUpdate(genericOp);
+      genericOp->setOperands(fusedOperands);
+      genericOp.indexing_mapsAttr(
+          rewriter.getAffineMapArrayAttr(fusedIndexMaps));
+      rewriter.finalizeRootUpdate(genericOp);
       return success();
     }
     return failure();
@@ -1013,7 +879,7 @@
 
 /// Pattern to move rank reducing reshape after an elementwise linalg generic
 /// op. This is useful to expose more fusion opportunities between named ops and
-/// generic op. This can only be done if there is no broadcast or permuation
+/// generic ops. This can only be done if there is no broadcast or permuation
 /// within the dimensions we need to merge.
 ///
 /// For example,
@@ -1040,27 +906,27 @@
 ///  %3 = linalg.tensor_reshape %2 [
 ///    #affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
 ///    : tensor<12544x16xf32> into tensor<112x112x16xf32>
-template <typename GenericOpTy>
-struct PushExpandingReshape : public OpRewritePattern<GenericOpTy> {
-  using OpRewritePattern<GenericOpTy>::OpRewritePattern;
+struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(GenericOpTy op,
+  LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
     // Only apply to elementwise linalg on tensor.
-    if (!op.hasTensorSemantics() ||
-        op.getNumParallelLoops() != op.getNumLoops())
+    if (!genericOp.hasTensorSemantics() ||
+        genericOp.getNumParallelLoops() != genericOp.getNumLoops())
      return failure();
    // Only support identity output maps. It could be extended to permuations if
    // needed.
-    if (llvm::any_of(op.getOutputIndexingMaps(),
+    if (llvm::any_of(genericOp.getOutputIndexingMaps(),
                     [](AffineMap map) { return !map.isIdentity(); }))
      return failure();
-    int64_t destRank = op.getNumParallelLoops();
-    SmallVector<Value, 4> newOperands = llvm::to_vector<4>(op.getInputs());
+    int64_t destRank = genericOp.getNumParallelLoops();
+    SmallVector<Value> newOperands =
+        llvm::to_vector<4>(genericOp.getInputs());
    TensorReshapeOp reshapeFound;
    // 1. Look for tensor_reshape operands and figure out save the dimensions
    // merged.
-    for (auto operand : llvm::enumerate(op.getInputs())) {
+    for (auto operand : llvm::enumerate(genericOp.getInputs())) {
      TensorReshapeOp reshapeOp =
          operand.value().template getDefiningOp<TensorReshapeOp>();
      if (!reshapeOp ||
          reshapeOp.getSrcType().getRank() >
@@ -1069,7 +935,7 @@
      }
      // TODO: We could support non-identity map as long as the merged
      // dimensions are still contiguous.
-      if (!op.getIndexingMaps()[operand.index()].isIdentity())
+      if (!genericOp.getIndexingMaps()[operand.index()].isIdentity())
        continue;
      if (reshapeFound) {
        // Only support a second reshape op if it has the same reassociate maps.
@@ -1087,7 +953,7 @@
    // Calculate the reassociation indices and rassociated reverse map.
    SmallVector<ReassociationIndices, 4> reassociation =
        getReassociationIndices(reshapeFound.getReassociationMaps());
-    SmallVector<int64_t, 4> remap(destRank);
+    SmallVector<int64_t> remap(destRank);
    for (auto &indices : llvm::enumerate(reassociation)) {
      for (int64_t index : indices.value()) {
        remap[index] = indices.index();
@@ -1096,9 +962,9 @@
    // 2. Verify that we can merge the dimensions in the linalg and that we
    // don't need to create new reshapes operands. Inserting new reshape
    // operands would defeat the purpose of the transformation.
-    for (auto operand : llvm::enumerate(op.getInputs())) {
+    for (auto operand : llvm::enumerate(genericOp.getInputs())) {
      if (operand.value() == newOperands[operand.index()]) {
-        AffineMap map = op.getIndexingMaps()[operand.index()];
+        AffineMap map = genericOp.getIndexingMaps()[operand.index()];
        for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
          if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
            return failure();
@@ -1108,70 +974,69 @@
 
    // 3. Calculate the affine map remapping and the reassociation to apply to
    // output tensors.
-    SmallVector<AffineMap, 4> newMaps;
+    SmallVector<AffineMap> newMaps;
    unsigned newRank = reassociation.size();
-    for (auto map : op.getIndexingMaps()) {
+    for (auto map : genericOp.getIndexingMaps()) {
      SmallVector<AffineExpr> newExprs;
      for (auto expr : map.getResults()) {
        unsigned position = expr.template cast<AffineDimExpr>().getPosition();
        // Skip dimension merged except for the last of the group.
        if (reassociation[remap[position]].back() == position) {
          newExprs.push_back(
-              getAffineDimExpr(remap[position], op.getContext()));
+              getAffineDimExpr(remap[position], genericOp.getContext()));
        }
      }
-      newMaps.push_back(AffineMap::get(newRank, 0, newExprs, op.getContext()));
+      newMaps.push_back(
+          AffineMap::get(newRank, 0, newExprs, genericOp.getContext()));
    }
 
    // 4. Reshape the output tensors.
    SmallVector<Value, 4> newOutputs;
    SmallVector<Type, 4> newOutputTypes;
-    for (auto output : op.outputs()) {
+    for (auto output : genericOp.outputs()) {
      auto newOutputType = RankedTensorType::get(
          reshapeFound.getSrcType().getShape(),
          output.getType().template cast<RankedTensorType>().getElementType());
      Value newOutput = rewriter.create<TensorReshapeOp>(
-          op->getLoc(), newOutputType, output, reassociation);
+          genericOp->getLoc(), newOutputType, output, reassociation);
      newOutputTypes.push_back(newOutputType);
      newOutputs.push_back(newOutput);
    }
    // 5. Create a new generic op with lowerer rank.
-    SmallVector<StringRef, 4> iteratorTypes(newRank,
-                                            getParallelIteratorTypeName());
-    auto newOp =
-        rewriter.create<GenericOp>(op->getLoc(), newOutputTypes, newOperands,
-                                   newOutputs, newMaps, iteratorTypes);
-    rewriter.inlineRegionBefore(op.region(), newOp.region(),
+    SmallVector<StringRef> iteratorTypes(newRank,
+                                         getParallelIteratorTypeName());
+    auto newOp = rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes,
+                                            newOperands, newOutputs, newMaps,
+                                            iteratorTypes);
+    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
                                newOp.region().begin());
    // 6. Reshape the so that the type matches the uses.
    SmallVector<Value, 4> newResults;
    for (auto result : llvm::enumerate(newOp->getResults())) {
      newResults.push_back(rewriter.create<TensorReshapeOp>(
-          op->getLoc(), op.getOutputTensorTypes()[result.index()],
+          genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()],
          result.value(), reassociation));
    }
-    rewriter.replaceOp(op, newResults);
+    rewriter.replaceOp(genericOp, newResults);
    return success();
  }
};
 
-/// Pattern to fuse a tensor_reshape op with its consumer
-/// generic/indexed_generic op, when the reshape op is collapsing
-/// dimensions. The dimensionality of the loop in the consumer is expanded.
-template <typename GenericOpTy>
+/// Pattern to fuse a tensor_reshape op with its consumer generic op, when the
+/// reshape op is collapsing dimensions. The dimensionality of the loop in the
+/// consumer is expanded.
 class FoldWithProducerReshapeOpByExpansion
-    : public OpRewritePattern<GenericOpTy> {
+    : public OpRewritePattern<GenericOp> {
 public:
  FoldWithProducerReshapeOpByExpansion(
      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
      PatternBenefit benefit = 1)
-      : OpRewritePattern<GenericOpTy>(context, benefit),
+      : OpRewritePattern<GenericOp>(context, benefit),
        controlFoldingReshapes(foldReshapes) {}
 
-  LogicalResult matchAndRewrite(GenericOpTy genericOp,
+  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
-    LinalgOp linalgOp = cast<LinalgOp>(genericOp.getOperation());
-    for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
+    for (auto operand : llvm::enumerate(genericOp.getInputs())) {
      TensorReshapeOp reshapeOp =
          operand.value().getDefiningOp<TensorReshapeOp>();
      if (!reshapeOp)
@@ -1181,14 +1046,14 @@
      //   - All constraints of fusing with reshape by expansion are met.
      if (reshapeOp.getSrcType().getRank() <
              reshapeOp.getResultType().getRank() ||
-          !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()) ||
+          !isFusableWithReshapeByDimExpansion(genericOp, operand.index()) ||
          (!controlFoldingReshapes(
              reshapeOp->getResult(0),
-              linalgOp.getInputOpOperands()[operand.index()])))
+              genericOp.getInputOpOperands()[operand.index()])))
        continue;
 
-      Optional<SmallVector<Value, 1>> replacementValues =
-          fuseWithReshapeByExpansion(linalgOp, reshapeOp, operand.index(),
+      Optional<SmallVector<Value>> replacementValues =
+          fuseWithReshapeByExpansion(genericOp, reshapeOp, operand.index(),
                                     rewriter);
      if (!replacementValues)
        return failure();
@@ -1211,10 +1076,9 @@
 
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
-    LinalgOp producer = reshapeOp.src().getDefiningOp<LinalgOp>();
-    if (!producer ||
-        !isa<GenericOp, IndexedGenericOp>(producer.getOperation()) ||
-        !producer.hasTensorSemantics() || producer.getNumOutputs() != 1 ||
+    GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
+    if (!producer || !producer.hasTensorSemantics() ||
+        producer.getNumOutputs() != 1 ||
        !isTensorReshapeOpFoldableByLinearization(
            reshapeOp, producer.getOutputIndexingMap(0),
            /*asProducer =*/false) ||
@@ -1251,8 +1115,8 @@
    Location loc = producer.getLoc();
    Value output = rewriter.create<TensorReshapeOp>(
        loc, producer.getOutputs()[0], reshapeOp.getReassociationExprs());
-    LinalgOp fusedOp = createLinalgOpOfSameType(
-        producer, rewriter, loc, reshapeOp.getResultType(),
+    auto fusedOp = rewriter.create<GenericOp>(
+        loc, reshapeOp.getResultType(),
        /*inputs=*/producer.getInputs(),
        // TODO: handle outputs.
        /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
        producer.iterator_types(),
@@ -1280,16 +1144,15 @@
    //   - All constraints of fusing with reshape by expansion are met.
    if (reshapeOp.getSrcType().getRank() > reshapeOp.getResultType().getRank())
      return failure();
-    LinalgOp producer = reshapeOp.src().getDefiningOp<LinalgOp>();
+    GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
    if (!producer || producer.getNumOutputs() != 1 ||
        !isFusableWithReshapeByDimExpansion(producer,
                                            producer.getNumInputs()) ||
        isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
                               reshapeOp.getReassociationMaps()))
      return failure();
-    Optional<SmallVector<Value, 1>> replacementValues =
-        fuseWithReshapeByExpansion(producer, reshapeOp, producer.getNumInputs(),
-                                   rewriter);
+    Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
+        producer, reshapeOp, producer.getNumInputs(), rewriter);
    if (!replacementValues)
      return failure();
    rewriter.replaceOp(reshapeOp, replacementValues.getValue());
@@ -1297,20 +1160,18 @@
  }
};
 
-/// Pattern to fold a GenericOp/IndexedGenericOp with a splat constant.
-template <typename LinalgOpTy>
-class FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
+/// Pattern to fold a generic op with a splat constant.
+class FoldSplatConstants : public OpRewritePattern<GenericOp> {
 public:
  FoldSplatConstants(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
                     PatternBenefit benefit = 1)
-      : OpRewritePattern<LinalgOpTy>(context, benefit), controlFn(fun) {}
+      : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
 
-  LogicalResult matchAndRewrite(LinalgOpTy op,
+  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
-    if (!op.hasTensorSemantics())
+    if (!genericOp.hasTensorSemantics())
      return failure();
-    LinalgOp linalgOp = cast<LinalgOp>(op.getOperation());
-    for (auto operand : llvm::enumerate(linalgOp.getInputOpOperands())) {
+    for (auto operand : llvm::enumerate(genericOp.getInputOpOperands())) {
      Operation *def = operand.value().get().getDefiningOp();
      DenseElementsAttr constantAttr;
      if (!def ||
@@ -1320,49 +1181,46 @@
        continue;
 
      // The indexing_maps for the operands of the fused operation are same as
-      // those for the operands of the linalgOp without the indexing map at
+      // those for the operands of the genericOp without the indexing map at
      // operand.index()
      SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
-          linalgOp.indexing_maps().getAsValueRange<AffineMapAttr, AffineMap>());
+          genericOp.indexing_maps().getAsValueRange<AffineMapAttr, AffineMap>());
      fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), operand.index()));
 
      // Check if the operation shapes to loops map is computable.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
-            linalgOp, "fused op loop bound computation failed");
+            genericOp, "fused op loop bound computation failed");
      }
 
-      // The operands list is same as the linalgOp with the argument for
+      // The operands list is same as the genericOp with the argument for
      // constant index dropped.
-      SmallVector<Value> fusedOperands(linalgOp.getInputs());
+      SmallVector<Value> fusedOperands(genericOp.getInputs());
      fusedOperands.erase(std::next(fusedOperands.begin(), operand.index()));
 
      // Create a constant scalar value from the splat constant.
      Value scalarConstant = rewriter.create<ConstantOp>(
          def->getLoc(), constantAttr.getSplatValue());
 
-      LinalgOp fusedOp = createLinalgOpOfSameType(
-          linalgOp, rewriter, rewriter.getUnknownLoc(),
-          linalgOp->getResultTypes(),
+      auto fusedOp = rewriter.create<GenericOp>(
+          rewriter.getUnknownLoc(), genericOp->getResultTypes(),
          /*inputs=*/fusedOperands,
-          /*outputs=*/linalgOp.getOutputs(),
+          /*outputs=*/genericOp.getOutputs(),
          rewriter.getAffineMapArrayAttr(fusedIndexMaps),
-          linalgOp.iterator_types(),
+          genericOp.iterator_types(),
          /*doc=*/nullptr,
          /*library_call=*/nullptr);
 
      // Map the block argument corresponding to the replaced argument with the
      // scalar constant.
-      Region &linalgOpRegion = linalgOp->getRegion(0);
-      Block &entryBlock = *linalgOpRegion.begin();
-      unsigned argIndex = entryBlock.getNumArguments() -
-                          linalgOp.getNumShapedOperands() + operand.index();
+      Region &region = genericOp->getRegion(0);
+      Block &entryBlock = *region.begin();
      BlockAndValueMapping mapping;
-      mapping.map(entryBlock.getArgument(argIndex), scalarConstant);
+      mapping.map(entryBlock.getArgument(operand.index()), scalarConstant);
      Region &fusedRegion = fusedOp->getRegion(0);
-      rewriter.cloneRegionBefore(linalgOpRegion, fusedRegion,
-                                 fusedRegion.begin(), mapping);
-      rewriter.replaceOp(linalgOp, fusedOp->getResults());
+      rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
+                                 mapping);
+      rewriter.replaceOp(genericOp, fusedOp->getResults());
      return success();
    }
    return failure();
@@ -1373,20 +1231,15 @@
};
} // namespace
 
-static Optional<SmallVector<Value, 1>>
+static Optional<SmallVector<Value>>
 fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand,
+                   GenericOp producer,
                   const ControlElementwiseOpsFusionFn &controlFn) {
-  Operation *producer = consumerOpOperand.get().getDefiningOp();
-  if (!producer || producer->getNumResults() != 1)
-    return llvm::None;
-
-  // Fuse when consumer is GenericOp or IndexedGenericOp.
-  if (!isa<GenericOp, IndexedGenericOp>(consumerOpOperand.getOwner()) ||
-      !isa<GenericOp, IndexedGenericOp>(producer))
+  if (producer->getNumResults() != 1)
    return llvm::None;
 
-  return fuseElementwiseOpsImpl(cast<LinalgOp>(producer), consumerOpOperand,
-                                controlFn, rewriter);
+  return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
+                                rewriter);
}
 
 bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
@@ -1398,25 +1251,24 @@
 
 namespace {
 /// Patterns to fuse a generic op, with the producer of its operands.
-template <typename LinalgOpTy>
-class FuseElementwiseOps : public OpRewritePattern<LinalgOpTy> {
+class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
 public:
  FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
                     PatternBenefit benefit = 1)
-      : OpRewritePattern<LinalgOpTy>(context, benefit), controlFn(fun) {}
+      : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
 
-  LogicalResult matchAndRewrite(LinalgOpTy op,
+  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
-    for (OpOperand &opOperand : op.getShapedOpOperands()) {
-      LinalgOp producerOp =
-          dyn_cast_or_null<LinalgOp>(opOperand.get().getDefiningOp());
-      if (!producerOp || !producerOp.hasTensorSemantics())
+    for (OpOperand &opOperand : genericOp.getShapedOpOperands()) {
+      auto producer =
+          dyn_cast_or_null<GenericOp>(opOperand.get().getDefiningOp());
+      if (!producer || !producer.hasTensorSemantics())
        continue;
-      Optional<SmallVector<Value, 1>> fusedOpResults =
-          fuseElementwiseOps(rewriter, opOperand, controlFn);
+      Optional<SmallVector<Value>> fusedOpResults =
+          fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
      if (fusedOpResults) {
-        rewriter.replaceOp(op, *fusedOpResults);
+        rewriter.replaceOp(genericOp, *fusedOpResults);
        return success();
      }
    }
@@ -1445,8 +1297,7 @@
  }
};
 
-/// Pass to test folding of reshape op with generic/indexed_generic ops by
-/// linearization.
+/// Pass to test folding of reshape ops with generic ops by linearization.
 struct FoldReshapeOpsByLinearizationPass
    : public LinalgFoldReshapeOpsByLinearizationBase<
          FoldReshapeOpsByLinearizationPass> {
@@ -1462,16 +1313,14 @@
 
 void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
    RewritePatternSet &patterns) {
-  patterns.add<FoldProducerReshapeOpByLinearization<GenericOp, false>,
-               FoldProducerReshapeOpByLinearization<IndexedGenericOp, false>,
+  patterns.add<FoldProducerReshapeOpByLinearization<false>,
               FoldConsumerReshapeOpByLinearization>(
      patterns.getContext());
}
 
 void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
    RewritePatternSet &patterns) {
-  patterns.add<FoldProducerReshapeOpByLinearization<GenericOp, true>,
-               FoldProducerReshapeOpByLinearization<IndexedGenericOp, true>,
+  patterns.add<FoldProducerReshapeOpByLinearization<true>,
              FoldConsumerReshapeOpByLinearization>(
      patterns.getContext());
}
 
@@ -1480,18 +1329,15 @@
    RewritePatternSet &patterns,
    ControlElementwiseOpsFusionFn controlFoldingReshapes) {
  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext());
-  patterns.add<FoldWithProducerReshapeOpByExpansion<GenericOp>,
-               FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>(
-      patterns.getContext(), controlFoldingReshapes);
+  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
+                                                     controlFoldingReshapes);
}
 
 void mlir::linalg::populateElementwiseOpsFusionPatterns(
    RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
  auto *context = patterns.getContext();
-  patterns
-      .add<FuseElementwiseOps<GenericOp>, FuseElementwiseOps<IndexedGenericOp>,
-           FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>(
-          context, options.controlElementwiseOpsFusionFn);
+  patterns.add<FuseElementwiseOps, FoldSplatConstants>(
+      context, options.controlElementwiseOpsFusionFn);
  populateFoldReshapeOpsByExpansionPatterns(patterns,
                                            options.controlFoldingReshapesFn);
  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
@@ -1502,8 +1348,7 @@
 
 void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
-  patterns.add<PushExpandingReshape<GenericOp>,
-               PushExpandingReshape<IndexedGenericOp>>(context);
+  patterns.add<PushExpandingReshape>(context);
}
 
 std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -207,34 +207,38 @@
 
 #map0 = affine_map<(d0, d1, d2) -> (d0)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func @indexed_generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>)
-  -> tensor<5x?x?xf32>
+func @generic_indexed_op_constant_fusion(%arg0 : tensor<5x?x?xindex>)
+  -> tensor<5x?x?xindex>
 {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index
-  %cst = constant dense<42.0> : tensor<5xf32>
-  %0 = memref.dim %arg0, %c1 : tensor<5x?x?xf32>
-  %1 = memref.dim %arg0, %c2 : tensor<5x?x?xf32>
-  %2 = linalg.init_tensor [5, %0, %1] : tensor<5x?x?xf32>
-  %3 = linalg.indexed_generic {
+  %cst = constant dense<42> : tensor<5xindex>
+  %0 = memref.dim %arg0, %c1 : tensor<5x?x?xindex>
+  %1 = memref.dim %arg0, %c2 : tensor<5x?x?xindex>
+  %2 = linalg.init_tensor [5, %0, %1] : tensor<5x?x?xindex>
+  %3 = linalg.generic {
    indexing_maps = [#map0, #map1, #map1],
    iterator_types = ["parallel", "parallel", "parallel"]}
-    ins(%cst, %arg0 : tensor<5xf32>, tensor<5x?x?xf32>)
-    outs(%2 : tensor<5x?x?xf32>) {
-    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: f32, %arg5 : f32, %arg6 : f32):
-      %4 = mulf %arg4, %arg5 : f32
-      linalg.yield %4 : f32
-  } -> tensor<5x?x?xf32>
-  return %3 : tensor<5x?x?xf32>
+    ins(%cst, %arg0 : tensor<5xindex>, tensor<5x?x?xindex>)
+    outs(%2 : tensor<5x?x?xindex>) {
+    ^bb0(%arg1: index, %arg2 : index, %arg3 : index):
+      %idx = linalg.index 0 : index
+      %4 = subi %arg1, %arg2 : index
+      %5 = addi %4, %idx : index
+      linalg.yield %5 : index
+  } -> tensor<5x?x?xindex>
+  return %3 : tensor<5x?x?xindex>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-LABEL: func @indexed_generic_op_constant_fusion
-//       CHECK: %[[CST:.*]] = constant {{.*}} : f32
+// CHECK-LABEL: func @generic_indexed_op_constant_fusion
+//       CHECK: %[[CST:.*]] = constant 42 : index
 //       CHECK: linalg.generic
 //       CHECK: ^{{[a-zA-Z0-9_]*}}
-//  CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]*]]: f32, %{{.*}}: f32)
-//       CHECK: mulf %[[CST]], %[[ARG4]]
+//  CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: index, %{{.*}}: index)
+//       CHECK: %[[IDX:.+]] = linalg.index 0 : index
+//       CHECK: %[[VAL1:.+]] = subi %[[CST]], %[[ARG2]]
+//       CHECK: %{{.*}} = addi %[[VAL1]], %[[IDX]]
 
 // -----
 
@@ -272,84 +276,38 @@
 #map0 = affine_map<(d0, d1, d2) -> ()>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func @indexed_generic_op_zero_dim_constant_fusion
-  (%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
+func @generic_indexed_op_zero_dim_constant_fusion
+  (%arg0 : tensor<5x?x?xindex>) -> tensor<5x?x?xindex>
 {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index
-  %cst = constant dense<42.0> : tensor<f32>
-  %0 = memref.dim %arg0, %c1 : tensor<5x?x?xf32>
-  %1 = memref.dim %arg0, %c2 : tensor<5x?x?xf32>
-  %2 = linalg.init_tensor [5, %0, %1] : tensor<5x?x?xf32>
-  %3 = linalg.indexed_generic {
+  %cst = constant dense<42> : tensor<index>
+  %0 = memref.dim %arg0, %c1 : tensor<5x?x?xindex>
+  %1 = memref.dim %arg0, %c2 : tensor<5x?x?xindex>
+  %2 = linalg.init_tensor [5, %0, %1] : tensor<5x?x?xindex>
+  %3 = linalg.generic {
    indexing_maps = [#map0, #map1, #map1],
    iterator_types = ["parallel", "parallel", "parallel"]}
-    ins(%cst, %arg0 : tensor<f32>, tensor<5x?x?xf32>)
-    outs(%2 : tensor<5x?x?xf32>) {
-    ^bb0(%arg1 : index, %arg2 : index, %arg3 : index, %arg4: f32, %arg5: f32, %arg6: f32):
-      %4 = mulf %arg4, %arg5 : f32
-      linalg.yield %4 : f32
-  } -> tensor<5x?x?xf32>
-  return %3 : tensor<5x?x?xf32>
+    ins(%cst, %arg0 : tensor<index>, tensor<5x?x?xindex>)
+    outs(%2 : tensor<5x?x?xindex>) {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index):
+      %idx = linalg.index 0 : index
+      %4 = subi %arg1, %arg2 : index
+      %5 = addi %4, %idx : index
+      linalg.yield %5 : index
+  } -> tensor<5x?x?xindex>
+  return %3 : tensor<5x?x?xindex>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-LABEL: func @indexed_generic_op_zero_dim_constant_fusion
-//       CHECK: %[[CST:.*]] = constant {{.*}} : f32
+// CHECK-LABEL: func @generic_indexed_op_zero_dim_constant_fusion
+//       CHECK: %[[CST:.*]] = constant 42 : index
 //       CHECK: linalg.generic
 //       CHECK: ^{{[a-zA-Z0-9_]*}}
-//  CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]*]]: f32, %{{.*}}: f32)
-//       CHECK: mulf %[[CST]], %[[ARG4]]
-
-// -----
-
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-func @generic_op_indexed_generic_op_fusion(%arg0: tensor<?x?xi32>,
-                                           %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  %0 = memref.dim %arg0, %c0 : tensor<?x?xi32>
-  %1 = memref.dim %arg0, %c1 : tensor<?x?xi32>
-  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xi32>
-  %3 = linalg.generic {
-      indexing_maps = [#map0, #map0, #map0],
-      iterator_types = ["parallel", "parallel"] }
-      ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>)
-      outs(%2 : tensor<?x?xi32>) {
-    ^bb0(%arg2: i32, %arg3: i32, %arg4: i32):       // no predecessors
-      %10 = addi %arg2, %arg3 : i32
-      linalg.yield %10 : i32
-  } -> tensor<?x?xi32>
-  %4 = linalg.indexed_generic {
-      indexing_maps = [#map0, #map0],
-      iterator_types = ["parallel", "parallel"] }
-      ins(%3 : tensor<?x?xi32>)
-      outs(%2 : tensor<?x?xi32>) {
-    ^bb0(%arg2: index, %arg3: index, %arg4: i32, %arg5: i32):       // no predecessors
-      %5 = index_cast %arg2 : index to i32
-      %6 = index_cast %arg3 : index to i32
-      %7 = addi %arg4, %5 : i32
-      %8 = subi %7, %6 : i32
-      linalg.yield %8 : i32
-  } -> tensor<?x?xi32>
-  return %4 : tensor<?x?xi32>
-}
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-LABEL: func @generic_op_indexed_generic_op_fusion
-//   CHECK-NOT: linalg.indexed_generic
-//       CHECK: linalg.generic
-//  CHECK-SAME:    indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]]
-//       CHECK: ^{{[a-zA-Z0-9_]*}}
-//  CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: i32
-//  CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]*]]: i32
-//       CHECK: %[[ARG0:.+]] = linalg.index 0 : index
-//       CHECK: %[[ARG1:.+]] = linalg.index 1 : index
-//       CHECK: %[[VAL1:.+]] = addi %[[ARG2]], %[[ARG3]] : i32
-//       CHECK: %[[ADD_OPERAND:.+]] = index_cast %[[ARG0]] : index to i32
-//       CHECK: %[[SUB_OPERAND:.+]] = index_cast %[[ARG1]] : index to i32
-//       CHECK: %[[VAL2:.+]] = addi %[[VAL1]], %[[ADD_OPERAND]] : i32
-//       CHECK: %[[VAL3:.+]] = subi %[[VAL2]], %[[SUB_OPERAND]] : i32
-//       CHECK: linalg.yield %[[VAL3]] : i32
+//  CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: index, %{{.*}}: index)
+//       CHECK: %[[IDX:.+]] = linalg.index 0 : index
+//       CHECK: %[[VAL1:.+]] = subi %[[CST]], %[[ARG2]]
+//       CHECK: %{{.*}} = addi %[[VAL1]], %[[IDX]]
 
 // -----
 
@@ -405,56 +363,6 @@
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-func @indexed_generic_op_generic_op_fusion(%arg0: tensor<?x?xi32>,
-                                           %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  %0 = memref.dim %arg0, %c0 : tensor<?x?xi32>
-  %1 = memref.dim %arg0, %c1 : tensor<?x?xi32>
-  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xi32>
-  %3 = linalg.indexed_generic {
-      indexing_maps = [#map0, #map0],
-      iterator_types = ["parallel", "parallel"] }
-      ins(%arg0 : tensor<?x?xi32>)
-      outs(%2 : tensor<?x?xi32>) {
-    ^bb0(%arg2: index, %arg3: index, %arg4: i32, %arg5: i32):       // no predecessors
-      %4 = index_cast %arg2 : index to i32
-      %5 = index_cast %arg3 : index to i32
-      %6 = addi %arg4, %4 : i32
-      %7 = subi %6, %5 : i32
-      linalg.yield %7 : i32
-  } -> tensor<?x?xi32>
-  %4 = linalg.generic {
-      indexing_maps = [#map0, #map0, #map0],
-      iterator_types = ["parallel", "parallel"] }
-      ins(%3, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>)
-      outs(%2 : tensor<?x?xi32>) {
-    ^bb0(%arg2: i32, %arg3: i32, %arg4: i32):       // no predecessors
-      %10 = addi %arg2, %arg3 : i32
-      linalg.yield %10 : i32
-  } -> tensor<?x?xi32>
-  return %4 : tensor<?x?xi32>
-}
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-LABEL: func @indexed_generic_op_generic_op_fusion
-//       CHECK: linalg.generic
-//  CHECK-SAME:    indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]]
-//       CHECK: ^{{[a-zA-Z0-9_]*}}
-//  CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: i32
-//  CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]*]]: i32
-//       CHECK: %[[ARG0:.+]] = linalg.index 0 : index
-//       CHECK: %[[ARG1:.+]] = linalg.index 1 : index
-//       CHECK: %[[ADD_OPERAND:.+]] = index_cast %[[ARG0]] : index to i32
-//       CHECK: %[[SUB_OPERAND:.+]] = index_cast %[[ARG1]] : index to i32
-//       CHECK: %[[VAL1:.+]] = addi %[[ARG2]], %[[ADD_OPERAND]] : i32
-//       CHECK: %[[VAL2:.+]] = subi %[[VAL1]], %[[SUB_OPERAND]] : i32
-//       CHECK: %[[VAL3:.+]] = addi %[[VAL2]], %[[ARG3]] : i32
-//       CHECK: linalg.yield %[[VAL3]] : i32
-//   CHECK-NOT: linalg.generic
-
-// -----
-
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 func @indexed_producer_consumer_fusion(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
  %c0 = constant 0 : index
@@ -506,63 +414,7 @@
 
 // -----
 
-// The indices of the first indexed_generic op are swapped after fusion.
-#map0 = affine_map<(d0, d1) -> (d1, d0)>
-#map1 = affine_map<(d0, d1) -> (d0, d1)>
-func @indexed_generic_op_fusion(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  %0 = memref.dim %arg0, %c0 : tensor<?x?xi32>
-  %1 = memref.dim %arg0, %c1 : tensor<?x?xi32>
-  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xi32>
-  %3 = linalg.indexed_generic {
-      indexing_maps = [#map0, #map0],
-      iterator_types = ["parallel", "parallel"] }
-      ins(%arg0 : tensor<?x?xi32>)
-      outs(%2 : tensor<?x?xi32>) {
-    ^bb0(%arg2: index, %arg3: index, %arg4: i32, %arg5: i32):       // no predecessors
-      %4 = index_cast %arg2 : index to i32
-      %5 = index_cast %arg3 : index to i32
-      %6 = addi %arg4, %4 : i32
-      %7 = subi %5, %6 : i32
-      linalg.yield %7 : i32
-  } -> tensor<?x?xi32>
-  %4= linalg.indexed_generic {
-      indexing_maps = [#map1, #map1],
-      iterator_types = ["parallel", "parallel"] }
-      ins(%3 : tensor<?x?xi32>)
-      outs(%2 : tensor<?x?xi32>) {
-    ^bb0(%arg2: index, %arg3: index, %arg4: i32, %arg5: i32):       // no predecessors
-      %5 = index_cast %arg2 : index to i32
-      %6 = index_cast %arg3 : index to i32
-      %7 = addi %arg4, %5 : i32
-      %8 = subi %7, %6 : i32
-      linalg.yield %8 : i32
-  } -> tensor<?x?xi32>
-  return %4 : tensor<?x?xi32>
-}
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-LABEL: func @indexed_generic_op_fusion
-//       CHECK: linalg.generic
-//  CHECK-SAME:    indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
-//       CHECK: ^{{[a-zA-Z0-9_]*}}
-//  CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: i32
-//       CHECK: %[[ARG0:.+]] = linalg.index 0 : index
-//       CHECK: %[[ARG1:.+]] = linalg.index 1 : index
-//       CHECK: %[[ADD_OPERAND1:.+]] = index_cast %[[ARG1]] : index to i32
-//       CHECK: %[[SUB_OPERAND1:.+]] = index_cast %[[ARG0]] : index to i32
-//       CHECK: %[[VAL1:.+]] = addi %[[ARG2]], %[[ADD_OPERAND1]] : i32
-//       CHECK: %[[VAL2:.+]] = subi %[[SUB_OPERAND1]], %[[VAL1]] : i32
-//       CHECK: %[[ADD_OPERAND2:.+]] = index_cast %[[ARG0]] : index to i32
-//       CHECK: %[[SUB_OPERAND2:.+]] = index_cast %[[ARG1]] : index to i32
-//       CHECK: %[[VAL3:.+]] = addi %[[VAL2]], %[[ADD_OPERAND2]] : i32
-//       CHECK: %[[VAL4:.+]] = subi %[[VAL3]], %[[SUB_OPERAND2]] : i32
-//       CHECK: linalg.yield %[[VAL4]] : i32
-//   CHECK-NOT: linalg.generic
-
-// -----
-
-// The indices of the first indexed_generic op are swapped after fusion.
+// The indices of the first generic op are swapped after fusion.
 #map0 = affine_map<(d0, d1) -> (d1, d0)>
 #map1 = affine_map<(d0, d1) -> (d0, d1)>
 func @indexed_producer_indexed_consumer_fusion(%arg0: tensor<?x?xi32>)
@@ -625,48 +477,6 @@
 
 // -----
 
-func @scalar_indexed_generic_fusion
-  (%arg0: tensor<5x1x1xf32>, %arg1 : tensor<i32>) -> tensor<10xf32>
-{
-  %c0 = constant 0 : index
-  %cst = constant dense<1.000000e+00> : tensor<10xf32>
-  %0 = linalg.init_tensor [] : tensor<f32>
-  %1 = linalg.indexed_generic
-    {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
-     iterator_types = []}
-    ins(%arg1 : tensor<i32>) outs(%0 : tensor<f32>) {
-  ^bb0(%arg2: i32, %arg3: f32):  // no predecessors
-    %3 = index_cast %arg2 : i32 to index
-    %4 = tensor.extract %arg0[%3, %c0, %c0] : tensor<5x1x1xf32>
-    linalg.yield %4 : f32
-  } -> tensor<f32>
-  %2 = linalg.init_tensor [10] : tensor<10xf32>
-  %3 = linalg.generic
-    {indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>,
-                      affine_map<(d0) -> (d0)>],
-     iterator_types = ["parallel"]}
-    ins(%1, %cst : tensor<f32>, tensor<10xf32>) outs(%2 : tensor<10xf32>) {
-  ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
-    %4 = mulf %arg2, %arg3 : f32
-    linalg.yield %4 : f32
-  } -> tensor<10xf32>
-  return %3 : tensor<10xf32>
-}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> ()>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
-// CHECK: func @scalar_indexed_generic_fusion
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<5x1x1xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<i32>
-// CHECK: %[[T0:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME: iterator_types = ["parallel"]
-// CHECK-SAME: ins(%[[ARG1]] : tensor<i32>)
-// CHECK: tensor.extract %[[ARG0]]
-// CHECK: linalg.yield
-// CHECK return %[[T0]]
-
-// -----
-
 func @scalar_generic_fusion
   (%arg0: tensor<5x1x1xf32>, %arg1 : tensor<i32>) -> tensor<10xf32>
 {
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -164,56 +164,6 @@
 // -----
 
-#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
-#map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
-func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
-                                                 %arg1 : tensor<?x?x?xi32>) ->
-  tensor<?x?x?xi32>
-{
-  %0 = linalg.tensor_reshape %arg0 [[0], [1, 2], [3]] :
-    tensor<?x?x4x?xi32> into tensor<?x?x?xi32>
-  %1 = linalg.indexed_generic {
-    indexing_maps = [#map0, #map1, #map1],
-    iterator_types = ["parallel", "parallel", "parallel"]}
-    ins(%0, %arg1 : tensor<?x?x?xi32>, tensor<?x?x?xi32>)
-    outs(%0 : tensor<?x?x?xi32>) {
-  ^bb0(%arg3 : index, %arg4 : index, %arg5 : index, %arg6: i32, %arg7: i32, %s: i32):
-    %1 = muli %arg6, %arg7 : i32
-    %2 = index_cast %arg3 : index to i32
-    %3 = addi %1, %2 : i32
-    %4 = index_cast %arg4 : index to i32
-    %5 = addi %3, %4 : i32
-    %6 = index_cast %arg5 : index to i32
-    %7 = addi %5, %6 : i32
-    linalg.yield %7 : i32
-  } -> tensor<?x?x?xi32>
-  return %1 : tensor<?x?x?xi32>
-}
-
-// The generic op version of the test check for the op structure. Only
-// checking the op body here.
-// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 4)>
-// CHECK: func @indexed_generic_op_reshape_producer_fusion
-// CHECK: linalg.generic
-// CHECK: ^{{.*}}(
-// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: i32, %[[ARG7:[a-zA-Z0-9]+]]: i32,
-// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: i32)
-// CHECK: %[[ARG2:.+]] = linalg.index 0 : index
-// CHECK: %[[ARG3:.+]] = linalg.index 1 : index
-// CHECK: %[[ARG4:.+]] = linalg.index 2 : index
-// CHECK: %[[ARG5:.+]] = linalg.index 3 : index
-// CHECK: %[[T3:.+]] = affine.apply #[[MAP]](%[[ARG3]], %[[ARG2]])
-// CHECK: %[[T4:.+]] = muli %[[ARG6]], %[[ARG7]]
-// CHECK: %[[T5:.+]] = index_cast %[[T3]]
-// CHECK: %[[T6:.+]] = addi %[[T4]], %[[T5]]
-// CHECK: %[[T7:.+]] = index_cast %[[ARG4]]
-// CHECK: %[[T8:.+]] = addi %[[T6]], %[[T7]]
-// CHECK: %[[T9:.+]] = index_cast %[[ARG5]]
-// CHECK: %[[T10:.+]] = addi %[[T8]], %[[T9]]
-// CHECK: linalg.yield %[[T10]]
-
-// -----
-
 #map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
 #map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
 func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
                                                %arg1 : tensor<?x?x?xi32>) ->
@@ -266,50 +216,6 @@
 
 // -----
 
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
-                                                 %arg1 : tensor<?x?xi32>) ->
-  tensor<?x?x4x5xi32>
-{
-  %0 = linalg.indexed_generic {
-    indexing_maps = [#map0, #map0, #map0],
-    iterator_types = ["parallel", "parallel"]}
-    ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>)
-    outs(%arg0 : tensor<?x?xi32>) {
-  ^bb0(%arg3 : index, %arg4 : index, %arg5: i32, %arg6: i32, %s: i32):  // no predecessors
-    %1 = muli %arg5, %arg6 : i32
-    %2 = index_cast %arg3 : index to i32
-    %3 = addi %1, %2 : i32
-    %4 = index_cast %arg4 : index to i32
-    %5 = addi %3, %4 : i32
-    linalg.yield %5 : i32
-  } -> tensor<?x?xi32>
-  %1 = linalg.tensor_reshape %0 [[0], [1, 2, 3]] :
-    tensor<?x?xi32> into tensor<?x?x4x5xi32>
-  return %1 : tensor<?x?x4x5xi32>
-}
-// The generic op version of the test check for the op structure. Only
-// checking the op body here.
-// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 5 + d2 * 20)>
-// CHECK: func @indexed_generic_op_reshape_consumer_fusion
-// CHECK: linalg.generic
-// CHECK: ^{{.*}}(
-// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: i32, %[[ARG7:[a-zA-Z0-9]+]]: i32,
-// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: i32)
-// CHECK: %[[ARG2:.+]] = linalg.index 0 : index
-// CHECK: %[[ARG3:.+]] = linalg.index 1 : index
-// CHECK: %[[ARG4:.+]] = linalg.index 2 : index
-// CHECK: %[[ARG5:.+]] = linalg.index 3 : index
-// CHECK: %[[T3:.+]] = affine.apply #[[MAP]](%[[ARG5]], %[[ARG4]], %[[ARG3]])
-// CHECK: %[[T4:.+]] = muli %[[ARG6]], %[[ARG7]]
-// CHECK: %[[T5:.+]] = index_cast %[[ARG2]]
-// CHECK: %[[T6:.+]] = addi %[[T4]], %[[T5]]
-// CHECK: %[[T7:.+]] = index_cast %[[T3]]
-// CHECK: %[[T8:.+]] = addi %[[T6]], %[[T7]]
-// CHECK: linalg.yield %[[T8]]
-
-// -----
-
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
                                                %arg1 : tensor<?x?xi32>) ->
@@ -356,69 +262,6 @@
 
 // -----
 
-func @reshape_as_consumer_permutation
-  (%a : tensor<210x6x4xi32>, %b : tensor<210x4xi32>)
-    -> tensor<2x3x4x5x6x7xi32> {
-  %shape = linalg.init_tensor [6, 4, 210] : tensor<6x4x210xi32>
-  %c = linalg.indexed_generic {
-    indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
-                     affine_map<(d0, d1, d2) -> (d1, d2)>,
-                     affine_map<(d0, d1, d2) -> (d0, d2, d1)>],
-    iterator_types = ["parallel", "parallel", "parallel"]}
-    ins(%a, %b : tensor<210x6x4xi32>, tensor<210x4xi32>)
-    outs(%shape : tensor<6x4x210xi32>) {
-  ^bb0(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : i32, %arg4: i32, %s: i32):
-    %1 = addi %arg3, %arg4 : i32
-    %2 = index_cast %arg0 : index to i32
-    %3 = addi %1, %2 : i32
-    %4 = index_cast %arg1 : index to i32
-    %5 = addi %3, %4 : i32
-    %6 = index_cast %arg2 : index to i32
-    %7 = addi %5, %6 : i32
-    linalg.yield %7 : i32
-  } -> tensor<6x4x210xi32>
-  %d = linalg.tensor_reshape %c [[0, 1], [2], [3, 4, 5]]
-    : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
-  return %d : tensor<2x3x4x5x6x7xi32>
-}
-// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
-// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
-// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
-// CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
-// CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 7 + d2 * 42)>
-// CHECK: func @reshape_as_consumer_permutation
-// CHECK-SAME: %[[ARG0:.+]]: tensor<210x6x4xi32>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<210x4xi32>
-// CHECK-DAG: %[[T1:.+]] = linalg.tensor_reshape %[[ARG0]]
-// CHECK-SAME: [0, 1, 2], [3, 4], [5]
-// CHECK-DAG: %[[T2:.+]] = linalg.tensor_reshape %[[ARG1]]
-// CHECK-SAME: [0, 1, 2], [3]
-// CHECK-DAG: %[[T0:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7]
-// CHECK: %[[T4:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]]]
-// CHECK-SAME: ins(%[[T1]], %[[T2]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>)
-// CHECK-SAME: outs(%[[T0]] : tensor<2x3x4x5x6x7xi32>)
-// CHECK: ^{{.+}}(
-// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: i32, %[[ARG9:[a-zA-Z0-9]+]]: i32,
-// CHECK-SAME: %[[ARG10:[a-zA-Z0-9]+]]: i32)
-// CHECK: %[[ARG2:.+]] = linalg.index 0 : index
-// CHECK: %[[ARG3:.+]] = linalg.index 1 : index
-// CHECK: %[[ARG4:.+]] = linalg.index 2 : index
-// CHECK: %[[ARG5:.+]] = linalg.index 3 : index
-// CHECK: %[[ARG6:.+]] = linalg.index 4 : index
-// CHECK: %[[ARG7:.+]] = linalg.index 5 : index
-// CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP8]](%[[ARG3]], %[[ARG2]])
-// CHECK-DAG: %[[T6:.+]] = affine.apply #[[MAP9]](%[[ARG6]], %[[ARG5]], %[[ARG4]])
-// CHECK-DAG: %[[T7:.+]] = addi %[[ARG8]], %[[ARG9]]
-// CHECK: %[[T8:.+]] = index_cast %[[T5]]
-// CHECK: %[[T9:.+]] = addi %[[T7]], %[[T8]]
-// CHECK: %[[T10:.+]] = index_cast %[[T6]]
-// CHECK: %[[T11:.+]] = addi %[[T9]], %[[T10]]
-// CHECK: %[[T12:.+]] = index_cast %[[ARG7]]
-// CHECK: %[[T13:.+]] = addi %[[T11]], %[[T12]]
-
-// -----
-
 func @reshape_as_consumer_permutation
   (%a : tensor<210x6x4xi32>, %b : tensor<210x4xi32>)
     -> tensor<2x3x4x5x6x7xi32> {
@@ -487,59 +330,6 @@
 
 // -----
 
-func @reshape_as_producer_projected_permutation(
-    %arg0 : tensor<33x8x?xi32>, %shape : tensor<264x?x4xi32>) -> tensor<264x?x4xi32>
-{
-  %0 = linalg.tensor_reshape %arg0 [[0, 1], [2]]
-    : tensor<33x8x?xi32> into tensor<264x?xi32>
-  %1 = linalg.indexed_generic
-    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>,
-                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
-     iterator_types = ["parallel", "parallel", "parallel"]}
-    ins(%0 : tensor<264x?xi32>)
-    outs(%shape : tensor<264x?x4xi32>) {
-  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: i32, %s: i32):  // no predecessors
-    %2 = index_cast %arg1 : index to i32
-    %3 = addi %arg4, %2 : i32
-    %4 = index_cast %arg2 : index to i32
-    %5 = addi %3, %4 : i32
-    %6 = index_cast %arg3 : index to i32
-    %7 = addi %5, %6 : i32
-    linalg.yield %7 : i32
-  } -> tensor<264x?x4xi32>
-  return %1 : tensor<264x?x4xi32>
-}
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 8)>
-// CHECK: @reshape_as_producer_projected_permutation
-// CHECK-SAME: %[[ARG0:.+]]: tensor<33x8x?xi32>
-// CHECK: %[[RES:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME: ins(%[[ARG0]] : tensor<33x8x?xi32>)
-// CHECK: ^{{.+}}(
-// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: i32,
-// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: i32)
-// CHECK: %[[ARG1:.+]] = linalg.index 0 : index
-// CHECK: %[[ARG2:.+]] = linalg.index 1 : index
-// CHECK: %[[ARG3:.+]] = linalg.index 2 : index
-// CHECK: %[[ARG4:.+]] = linalg.index 3 : index
-// CHECK: %[[T0:.+]] = affine.apply #[[MAP2]](%[[ARG2]], %[[ARG1]])
-// CHECK: %[[T1:.+]] = index_cast %[[T0]] : index to i32
-// CHECK: %[[T2:.+]] = addi %[[ARG5]], %[[T1]] : i32
-// CHECK: %[[T3:.+]] = index_cast %[[ARG3]] : index to i32
-// CHECK: %[[T4:.+]] = addi %[[T2]], %[[T3]] : i32
-// CHECK: %[[T5:.+]] = index_cast %[[ARG4]] : index to i32
-// CHECK: %[[T6:.+]] = addi %[[T4]], %[[T5]] : i32
-// CHECK: linalg.yield %[[T6]] : i32
-// CHECK: %[[RES2:.+]] = linalg.tensor_reshape %[[RES]]
-// CHECK-SAME: [0, 1], [2], [3]
-// CHECK-SAME: : tensor<33x8x?x4xi32> into tensor<264x?x4xi32>
-// CHECK: return %[[RES2]] : tensor<264x?x4xi32>
-
-// -----
-
 func @reshape_as_producer_projected_permutation(
     %arg0 : tensor<33x8x?xi32>, %shape : tensor<264x?x4xi32>) -> tensor<264x?x4xi32>
 {
diff --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
--- a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
@@ -1,75 +1,18 @@
 // RUN: mlir-opt -split-input-file -linalg-fold-reshape-ops-by-linearization %s | FileCheck %s
 
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xf32>,
-                                         %arg1 : tensor<?x?x4x?xf32>) -> tensor<?x?x4x?xf32> {
-  %0 = linalg.tensor_reshape %arg0 [[0], [1, 2], [3]] :
-    tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
-  %1 = linalg.generic {
-    indexing_maps = [#map0, #map0, #map0],
-    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-    ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>)
-    outs(%0 : tensor<?x?x4x?xf32>) {
-  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
-    %1 = mulf %arg3, %arg4 : f32
-    linalg.yield %1 : f32
-  } -> tensor<?x?x4x?xf32>
-  return %1 : tensor<?x?x4x?xf32>
-}
-// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
-// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK: func @generic_op_reshape_producer_fusion
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
-// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
-// CHECK-SAME: [0], [1, 2], [3]
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP3]], #[[MAP4]], #[[MAP4]]]
-// CHECK-SAME: ins(%[[ARG0]], %{{.+}} : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>)
-// CHECK-SAME: outs(%[[T0]] : tensor<?x?x4x?xf32>)
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xf32>,
-                                         %arg1 : tensor<?x?x4x5xf32>) -> tensor<?x?xf32> {
-  %0 = linalg.generic {
-    indexing_maps = [#map0, #map0, #map0],
-    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-    ins(%arg0, %arg1 : tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32>)
-    outs(%arg0 : tensor<?x?x4x5xf32>){
-  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
-    %1 = mulf %arg3, %arg4 : f32
-    linalg.yield %1 : f32
-  } -> tensor<?x?x4x5xf32>
-  %1 = linalg.tensor_reshape %0 [[0], [1, 2, 3]] :
-    tensor<?x?x4x5xf32> into tensor<?x?xf32>
-  return %1 : tensor<?x?xf32>
-}
-
-// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
-// CHECK: func @generic_op_reshape_consumer_fusion
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x5xf32>
-// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
-// CHECK-SAME: [0], [1, 2, 3]
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]]]
-// CHECK-SAME: outs(%[[T0]] : tensor<?x?xf32>)
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
+func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
     -> tensor<?x?x4x?xi32> {
   %0 = linalg.tensor_reshape %arg0 [[0], [1, 2], [3]] :
     tensor<?x?x?xi32> into tensor<?x?x4x?xi32>
-  %1 = linalg.indexed_generic {
+  %1 = linalg.generic {
    indexing_maps = [#map0, #map0],
    iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
    ins(%0 : tensor<?x?x4x?xi32>) outs(%0 : tensor<?x?x4x?xi32>) {
-  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32, %arg7 : i32):  // no predecessors
-    %2 = index_cast %arg2 : index to i32
+  ^bb0(%arg6: i32, %arg7 : i32):  // no predecessors
+    %idx = linalg.index 0 : index
+    %2 = index_cast %idx : index to i32
     %3 = addi %arg6, %2 : i32
     linalg.yield %3 : i32
   } -> tensor<?x?x4x?xi32>
@@ -77,26 +20,29 @@
 }
 // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
 // CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK: func @indexed_generic_op_reshape_producer_fusion
+// CHECK: func @generic_op_reshape_producer_fusion
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xi32>
 // CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
 // CHECK-SAME: [0], [1, 2], [3]
-// CHECK: linalg.indexed_generic
+// CHECK: linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP3]], #[[MAP4]]]
 // CHECK-SAME: ins(%[[ARG0]] : tensor<?x?x?xi32>)
 // CHECK-SAME: outs(%[[T0]] : tensor<?x?x4x?xi32>)
+// CHECK: %[[IDX:.+]] = linalg.index 0 : index
+// CHECK-NEXT: %[[IDX_CASTED:.+]] = index_cast %[[IDX]] : index to i32
 
 // -----
 
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
+func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
     -> tensor<?x?xi32> {
-  %0 = linalg.indexed_generic {
+  %0 = linalg.generic {
    indexing_maps = [#map0, #map0],
    iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
    ins(%arg0 : tensor<?x?x4x5xi32>) outs(%arg0 : tensor<?x?x4x5xi32>) {
-  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32, %arg7: i32):  // no predecessors
-    %2 = index_cast %arg2 : index to i32
+  ^bb0(%arg6: i32, %arg7: i32):  // no predecessors
+    %idx = linalg.index 0 : index
+    %2 = index_cast %idx : index to i32
     %3 = addi %arg6, %2 : i32
     linalg.yield %3 : i32
   } -> tensor<?x?x4x5xi32>
@@ -106,13 +52,15 @@
 }
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
-// CHECK: func @indexed_generic_op_reshape_consumer_fusion
+// CHECK: func @generic_op_reshape_consumer_fusion
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x5xi32>
 // CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
 // CHECK-SAME: [0], [1, 2, 3]
-// CHECK: linalg.indexed_generic
+// CHECK: linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]]
 // CHECK-SAME: outs(%[[T0]] : tensor<?x?xi32>)
+// CHECK: %[[IDX:.+]] = linalg.index 0 : index
+// CHECK-NEXT: %[[IDX_CASTED:.+]] = index_cast %[[IDX]] : index to i32
 // CHECK-NOT: linalg.tensor_reshape
 
 // -----