diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h --- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h +++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h @@ -82,8 +82,8 @@ if (!owner) return llvm::None; if (OpOperand *operand = opView.dyn_cast()) - return owner.getTiedIndexingMap(operand); - return owner.getTiedIndexingMap(owner.getOutputOperand( + return owner.getMatchingIndexingMap(operand); + return owner.getMatchingIndexingMap(owner.getOutputOperand( opView.get().cast().getResultNumber())); } // Return the operand number if the `opView` is an OpOperand *. Otherwise diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -377,7 +377,7 @@ Return the block argument for an `opOperand`. }], /*retTy=*/"BlockArgument", - /*methodName=*/"getTiedBlockArgument", + /*methodName=*/"getMatchingBlockArgument", /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ @@ -390,7 +390,7 @@ Return the operand for a `blockArgument`. }], /*retTy=*/"OpOperand *", - /*methodName=*/"getTiedOpOperand", + /*methodName=*/"getMatchingOpOperand", /*args=*/(ins "BlockArgument":$blockArgument), /*methodBody=*/"", /*defaultImplementation=*/[{ @@ -404,7 +404,7 @@ Return the input or output indexing map for `opOperand`. }], /*retTy=*/"AffineMap", - /*methodName=*/"getTiedIndexingMap", + /*methodName=*/"getMatchingIndexingMap", /*args=*/(ins "OpOperand*":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ @@ -419,7 +419,7 @@ Return the indexing map for a `result`. }], /*retTy=*/"AffineMap", - /*methodName=*/"getTiedIndexingMapForResult", + /*methodName=*/"getIndexingMapMatchingResult", /*args=*/(ins "OpResult":$result), /*methodBody=*/"", /*defaultImplementation=*/[{ @@ -442,7 +442,7 @@ `opOperand`. }], /*retTy=*/"OpOperand *", - /*methodName=*/"getTiedYieldValue", + /*methodName=*/"getMatchingYieldValue", /*args=*/(ins "OpOperand*":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -34,7 +34,7 @@ for (auto *opOperand : linalgOp.getInputAndOutputOperands()) { if (llvm::is_contained(droppedOperands, opOperand)) continue; - indexingMaps.push_back(linalgOp.getTiedIndexingMap(opOperand)); + indexingMaps.push_back(linalgOp.getMatchingIndexingMap(opOperand)); } return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap(); } @@ -658,7 +658,7 @@ << linalgOp.getNumInputsAndOutputs() << ")"; for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { - AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand); + AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand); // Symbols disallowed. if (indexingMap.getNumSymbols() != 0) @@ -696,7 +696,7 @@ for (int64_t &range : endLoopRangeValues) range -= 1; for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { - AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand); + AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand); SmallVector startIndices = indexingMap.compose(startLoopRangeValues); SmallVector endIndices = diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -945,7 +945,7 @@ // Check if this operand is a duplicate. AffineMap indexingMap = - genericOp.getTiedIndexingMap(inputOpOperand.value()); + genericOp.getMatchingIndexingMap(inputOpOperand.value()); auto it = dedupedInputs.find( std::make_pair(inputOpOperand.value()->get(), indexingMap)); if (it != dedupedInputs.end()) { @@ -984,7 +984,7 @@ origToNewPos[outputOpOperand.index()] = newOutputOperands.size(); newOutputOperands.push_back(outputOpOperand.value()->get()); newIndexingMaps.push_back( - genericOp.getTiedIndexingMap(outputOpOperand.value())); + genericOp.getMatchingIndexingMap(outputOpOperand.value())); } } else { // Output argument can be dropped if the result has @@ -997,7 +997,7 @@ llvm::enumerate(genericOp.getOutputOperands())) { Value result = genericOp.getResult(outputOpOperand.index()); AffineMap indexingMap = - genericOp.getTiedIndexingMap(outputOpOperand.value()); + genericOp.getMatchingIndexingMap(outputOpOperand.value()); auto key = std::make_tuple(outputOpOperand.value()->get(), indexingMap, yieldOp->getOperand(outputOpOperand.index())); @@ -1033,7 +1033,7 @@ dedupedOutpts[key] = newOutputOperands.size(); newOutputOperands.push_back(outputOpOperand.value()->get()); newIndexingMaps.push_back( - genericOp.getTiedIndexingMap(outputOpOperand.value())); + genericOp.getMatchingIndexingMap(outputOpOperand.value())); } } @@ -1957,7 +1957,7 @@ continue; Value src = opOperand->get(); auto sourceType = src.getType().cast(); - auto sourceMap = linalgOp.getTiedIndexingMap(opOperand); + auto sourceMap = linalgOp.getMatchingIndexingMap(opOperand); // Get the `sourceShape` of the `sourceType`. If the operand is a result of // `tensor.cast` operation and source of the cast operation has a static @@ -2005,7 +2005,7 @@ return; } ArrayRef sourceShape = sourceType.getShape(); - AffineMap sourceMap = linalgOp.getTiedIndexingMap(opOperand); + AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand); SmallVector newShape; // If operand is updated with new shape, `newOperandNeeded` will be // true. diff --git a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp @@ -81,7 +81,7 @@ } OpOperand *outOperand = linalgOp.getOutputOperand(0); - AffineMap indexingMap = linalgOp.getTiedIndexingMap(outOperand); + AffineMap indexingMap = linalgOp.getMatchingIndexingMap(outOperand); if (!indexingMap.isProjectedPermutation()) { return rewriter.notifyMatchFailure( sliceOp, "expected a projected permutation for output"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp @@ -180,7 +180,7 @@ OpResult result = genericOp.getResult(*resultNumber).cast(); newResultTypes.push_back(result.getType()); peeledGenericOpIndexingMaps.push_back( - genericOp.getTiedIndexingMapForResult(result)); + genericOp.getIndexingMapMatchingResult(result)); continue; } @@ -227,15 +227,16 @@ /// as those used for the new results of the peeledGenericOp. auto indexingMaps = llvm::to_vector( llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *operand) { - return genericOp.getTiedIndexingMap(operand); + return genericOp.getMatchingIndexingMap(operand); })); for (auto resultNum : llvm::seq(origNumResults, peeledGenericOpNumResults)) { OpResult result = peeledGenericOp.getResult(resultNum).cast(); - indexingMaps.push_back(peeledGenericOp.getTiedIndexingMapForResult(result)); + indexingMaps.push_back( + peeledGenericOp.getIndexingMapMatchingResult(result)); } for (OpOperand *outOperand : genericOp.getOutputOperands()) - indexingMaps.push_back(genericOp.getTiedIndexingMap(outOperand)); + indexingMaps.push_back(genericOp.getMatchingIndexingMap(outOperand)); auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps); return rewriter.create( @@ -263,7 +264,7 @@ } if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *outOperand) { - return !genericOp.getTiedIndexingMap(outOperand).isPermutation(); + return !genericOp.getMatchingIndexingMap(outOperand).isPermutation(); })) { return rewriter.notifyMatchFailure( genericOp, "unhandled decomposition of generic op with out operand not " diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -245,7 +245,7 @@ static llvm::Optional replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand, MLIRContext *context) { - AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); + AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); ArrayRef shape = genericOp.getShape(opOperand); ArrayRef exprs = indexingMap.getResults(); SmallVector reassociations; @@ -390,7 +390,7 @@ // type, indexing map, and create a set of mappings representing an // identity matrix. newInputOutputTypes.push_back(opOperand->get().getType()); - newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand)); + newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(opOperand)); int64_t origRank = genericOp.getRank(opOperand); auto maps = llvm::to_vector<8>(llvm::map_range( llvm::seq(0, origRank), [&](int64_t dim) -> Attribute { diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -59,7 +59,7 @@ LinalgOp producer = cast(producerOpOperand->getOwner()); // argMap is a map from producer loop -> producer arg tensor index. - AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand); + AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand); // Compose argMap with invProducerResultIndexMap to get a map from // producer result tensor index -> producer arg tensor index. @@ -95,14 +95,14 @@ // Get the consumer index map. The number of results of the consumer index // map must match the number of loops of the producer. - AffineMap consumerIndexMap = consumer.getTiedIndexingMap(fusedOperand); + AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand); if (consumerIndexMap.getNumResults() != producer.getNumLoops()) return false; // Finally the index_map for the result must be invertible. For now just // verify it is a permutation. AffineMap producerResultIndexMap = - producer.getTiedIndexingMap(producer.getOutputOperand(0)); + producer.getMatchingIndexingMap(producer.getOutputOperand(0)); if (!producerResultIndexMap.isPermutation()) return false; @@ -288,17 +288,17 @@ assert(it != consumerInputs.end() && "expected to find the consumer operand"); for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) { fusedInputOperands.push_back(opOperand->get()); - fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); + fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand)); } // 4. Splice in producer's input operands/maps. AffineMap producerResultIndexMap = - producer.getTiedIndexingMapForResult(producerResult); + producer.getIndexingMapMatchingResult(producerResult); for (OpOperand *opOperand : producer.getInputOperands()) { fusedInputOperands.push_back(opOperand->get()); // Compute indexing maps for the producer args in the fused operation. AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( opOperand, producerResultIndexMap, - consumer.getTiedIndexingMap(fusedOperand)); + consumer.getMatchingIndexingMap(fusedOperand)); fusedIndexMaps.push_back(map); } // 5. Remaining consumer's input operands/maps (drop past index @@ -306,7 +306,7 @@ for (OpOperand *opOperand : llvm::make_range(std::next(it), consumerInputs.end())) { fusedInputOperands.push_back(opOperand->get()); - fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); + fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand)); } // 6. Collect all of the producer outputs. @@ -314,7 +314,7 @@ fusedOutputOperands.push_back(opOperand->get()); AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( opOperand, producerResultIndexMap, - consumer.getTiedIndexingMap(fusedOperand)); + consumer.getMatchingIndexingMap(fusedOperand)); fusedIndexMaps.push_back(map); fusedResultTypes.push_back(opOperand->get().getType()); } @@ -322,7 +322,7 @@ // 7. All of consumer's output operands (skip operands: added by the builder). for (OpOperand *opOperand : consumer.getOutputOperands()) { fusedOutputOperands.push_back(opOperand->get()); - fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); + fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand)); fusedResultTypes.push_back(opOperand->get().getType()); } @@ -344,7 +344,8 @@ // Construct an AffineMap from consumer loops to producer loops. // consumer loop -> tensor index - AffineMap consumerResultIndexMap = consumer.getTiedIndexingMap(fusedOperand); + AffineMap consumerResultIndexMap = + consumer.getMatchingIndexingMap(fusedOperand); // tensor index -> producer loop AffineMap invProducerResultIndexMap = inversePermutation(producerResultIndexMap); @@ -466,7 +467,7 @@ .getValue() .isProjectedPermutation(); }) && - genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 && + genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() > 0 && llvm::all_of(genericOp.getIteratorTypesArray(), [](StringRef it) { return it == getParallelIteratorTypeName(); }); @@ -517,7 +518,7 @@ PatternRewriter &rewriter) { if (reassociationMaps.empty()) return failure(); - AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand); + AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand); SmallVector originalLoopRange = linalgOp.getStaticLoopRanges(); originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end()); @@ -727,7 +728,7 @@ continue; } if (genericOp.isInputTensor(opOperand)) { - AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); + AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); auto opOperandType = opOperand->get().getType().cast(); RankedTensorType expandedOperandType = getExpandedType(opOperandType, indexingMap, expansionInfo); @@ -755,7 +756,7 @@ Location loc = genericOp.getLoc(); SmallVector outputs; for (OpOperand *opOperand : genericOp.getOutputOperands()) { - AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); + AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); auto opOperandType = opOperand->get().getType().cast(); RankedTensorType expandedOutputType = getExpandedType(opOperandType, indexingMap, expansionInfo); @@ -802,7 +803,7 @@ if (resultTypes[resultNumber] != opResult.getType()) { SmallVector reassociation = getReassociationForExpansion( - genericOp.getTiedIndexingMap( + genericOp.getMatchingIndexingMap( genericOp.getOutputOperand(resultNumber)), expansionInfo); resultVals.push_back(rewriter.create( @@ -1063,7 +1064,7 @@ } llvm::SmallDenseSet processedIterationDims; - AffineMap indexingMap = genericOp.getTiedIndexingMap(fusableOperand); + AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand); auto iteratorTypes = genericOp.getIteratorTypes().getValue(); SmallVector iterationSpaceReassociation; for (ReassociationIndicesRef foldedRangeDims : reassociation) { @@ -1312,7 +1313,7 @@ OpOperand *opOperand, const CollapsingInfo &collapsingInfo, OpBuilder &builder) { - AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); + AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); SmallVector operandReassociation = getOperandReassociation(indexingMap, collapsingInfo); @@ -1470,7 +1471,7 @@ auto collapsedOpResultType = collapsedOpResult.getType().cast(); if (collapsedOpResultType.getRank() != originalResultType.getRank()) { AffineMap indexingMap = - genericOp.getTiedIndexingMapForResult(originalResult.value()); + genericOp.getIndexingMapMatchingResult(originalResult.value()); SmallVector reassociation = getOperandReassociation(indexingMap, collapsingInfo); Value result = rewriter.create( @@ -1594,12 +1595,14 @@ if (inputOperand == opOperand) continue; Value inputValue = inputOperand->get(); - fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand)); + fusedIndexMaps.push_back( + genericOp.getMatchingIndexingMap(inputOperand)); fusedOperands.push_back(inputValue); fusedLocs.push_back(inputValue.getLoc()); } for (OpOperand *outputOperand : genericOp.getOutputOperands()) - fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand)); + fusedIndexMaps.push_back( + genericOp.getMatchingIndexingMap(outputOperand)); // Check if the operation shapes to loops map is computable. if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -80,7 +80,7 @@ opOperand->get().getDefiningOp())) continue; - AffineMap map = op.getTiedIndexingMap(opOperand); + AffineMap map = op.getMatchingIndexingMap(opOperand); LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange I/O idx: " << opOperand->getOperandNumber() << "\n"); LLVM_DEBUG(llvm::dbgs() @@ -442,7 +442,7 @@ OpOperand *opOperand = producerOp.getOutputOperand(producerOpResult.getResultNumber()); LinalgOp fusedProducer = - fuse(b, producerOp, producerOp.getTiedIndexingMap(opOperand), + fuse(b, producerOp, producerOp.getMatchingIndexingMap(opOperand), consumerOpOperand); // Replace use. diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -38,7 +38,7 @@ ArrayRef tiledLoopDims) { // Get the consumer operand indexing map. LinalgOp consumerOp = consumerOperand->getOwner(); - AffineMap indexingMap = consumerOp.getTiedIndexingMap(consumerOperand); + AffineMap indexingMap = consumerOp.getMatchingIndexingMap(consumerOperand); // Search the slice dimensions tiled by a tile loop dimension. DenseSet tiledSliceDimIndices; @@ -68,7 +68,7 @@ // Get the indexing map of the `producerOp` output operand that matches // ´producerResult´. - AffineMap producerIndexingMap = producerOp.getTiedIndexingMap( + AffineMap producerIndexingMap = producerOp.getMatchingIndexingMap( producerOp.getOutputOperand(producerResult.getResultNumber())); // Keep only the tiled result slice dimensions of `producerIndexingMap`. @@ -351,7 +351,7 @@ // Check `consumerOpOperand` is not shape-only to avoid fusion if the data is // not used by the `consumerOp` computation. - BlockArgument bbArg = consumerOp.getTiedBlockArgument(consumerOpOperand); + BlockArgument bbArg = consumerOp.getMatchingBlockArgument(consumerOpOperand); if (bbArg.getUses().empty()) return failure(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp @@ -42,7 +42,7 @@ SmallVector newIndexingMaps; SmallVector newOperands; for (OpOperand *opOperand : genericOp.getInputOperands()) { - AffineMap map = genericOp.getTiedIndexingMap(opOperand); + AffineMap map = genericOp.getMatchingIndexingMap(opOperand); if (genericOp.isInputTensor(opOperand) && map.isConstant()) { scalarOperands.emplace_back(opOperand->getOperandNumber()); } else { @@ -55,7 +55,7 @@ return failure(); for (OpOperand *opOperand : genericOp.getOutputOperands()) - newIndexingMaps.emplace_back(genericOp.getTiedIndexingMap(opOperand)); + newIndexingMaps.emplace_back(genericOp.getMatchingIndexingMap(opOperand)); Location loc = genericOp->getLoc(); SmallVector outputOperands = genericOp.getOutputOperands(); @@ -71,7 +71,7 @@ for (auto idx : llvm::reverse(scalarOperands)) { OpOperand *opOperand = genericOp.getInputOperand(idx); - AffineMap map = genericOp.getTiedIndexingMap(opOperand); + AffineMap map = genericOp.getMatchingIndexingMap(opOperand); SmallVector indices = map.getConstantResults(); SmallVector indicesValues; for (auto idx : indices) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp @@ -68,7 +68,7 @@ // 2. Compute the interchanged indexing maps. SmallVector newIndexingMaps; for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) { - AffineMap m = genericOp.getTiedIndexingMap(opOperand); + AffineMap m = genericOp.getMatchingIndexingMap(opOperand); if (!permutationMap.isEmpty()) m = m.compose(permutationMap); newIndexingMaps.push_back(m); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -144,14 +144,14 @@ continue; } auto indexing = makeCanonicalAffineApplies( - b, loc, linalgOp.getTiedIndexingMap(inputOperand), allIvsPlusDims); + b, loc, linalgOp.getMatchingIndexingMap(inputOperand), allIvsPlusDims); indexedValues.push_back( b.create(loc, inputOperand->get(), indexing)); } // 1.b. Emit load from output views. for (OpOperand *outputOperand : linalgOp.getOutputOperands()) { SmallVector indexing = makeCanonicalAffineApplies( - b, loc, linalgOp.getTiedIndexingMap(outputOperand), allIvsPlusDims); + b, loc, linalgOp.getMatchingIndexingMap(outputOperand), allIvsPlusDims); indexedValues.push_back( b.create(loc, outputOperand->get(), indexing)); } @@ -163,7 +163,8 @@ SmallVector outputBuffers; for (OpOperand *outputOperand : linalgOp.getOutputBufferOperands()) { indexing.push_back(makeCanonicalAffineApplies( - b, loc, linalgOp.getTiedIndexingMap(outputOperand), allIvsPlusDims)); + b, loc, linalgOp.getMatchingIndexingMap(outputOperand), + allIvsPlusDims)); outputBuffers.push_back(outputOperand->get()); } inlineRegionAndEmitStore(b, loc, linalgOp, indexedValues, diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -117,7 +117,7 @@ SmallVector newMaps; // Calculate the new shapes and indexing maps of the input operands. for (OpOperand *operand : op.getInputOperands()) { - AffineMap map = op.getTiedIndexingMap(operand); + AffineMap map = op.getMatchingIndexingMap(operand); SmallVector newShape; SmallVector exprs; SmallVector reassociation; @@ -171,7 +171,7 @@ // Calculate the new output map and shape, we insert the new dimension based // on the index returned by `controlSplitReductionFn`. SmallVector newOutputShape; - AffineMap oldOutputMap = op.getTiedIndexingMap(op.getOutputOperand(0)); + AffineMap oldOutputMap = op.getMatchingIndexingMap(op.getOutputOperand(0)); ArrayRef oldShape = op.getShape(op.getOutputOperand(0)); SmallVector outputExpr; for (unsigned idx : @@ -273,7 +273,7 @@ int64_t reductionRatio) { auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext()); auto reductionDimP1 = getAffineDimExpr(reductionDimPos + 1, op.getContext()); - AffineMap map = op.getTiedIndexingMap(&opOperand); + AffineMap map = op.getMatchingIndexingMap(&opOperand); AffineMap idMap = AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext()); AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1); @@ -286,7 +286,7 @@ static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand, unsigned reductionDimPos, int64_t size) { auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext()); - AffineMap map = op.getTiedIndexingMap(&opOperand); + AffineMap map = op.getMatchingIndexingMap(&opOperand); AffineMap idMap = AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext()); AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1); diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -62,7 +62,7 @@ Value toStore = map.lookupOrDefault(operand.value()); OpOperand *storeInto = linalgOp.getOutputOperand(operand.index()); auto indices = getIndicesForAccess( - b, loc, linalgOp.getTiedIndexingMap(storeInto), ivs); + b, loc, linalgOp.getMatchingIndexingMap(storeInto), ivs); b.create(loc, toStore, linalgOp.getOutputOperand(operand.index())->get(), indices); @@ -162,10 +162,10 @@ })); OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber); - SliceParameters sliceParams = - computeSliceParameters(b, loc, outOperand->get(), sizes, - linalgOp.getTiedIndexingMap(outOperand), offsets, - /*ubs*/ {}, subShapeSizes, true); + SliceParameters sliceParams = computeSliceParameters( + b, loc, outOperand->get(), sizes, + linalgOp.getMatchingIndexingMap(outOperand), offsets, + /*ubs*/ {}, subShapeSizes, true); resultOffsets = sliceParams.offsets; resultSizes = sliceParams.sizes; return success(); @@ -182,7 +182,7 @@ // map the offsets and sizes from the result to iteration space tiles // (filling in full extent for dimensions not used to access the result). AffineMap indexingMap = - linalgOp.getTiedIndexingMapForResult(op->getResult(resultNumber)); + linalgOp.getIndexingMapMatchingResult(op->getResult(resultNumber)); if (!indexingMap.isProjectedPermutation()) { return op->emitOpError( "unhandled tiled implementation generation when result is not " @@ -238,7 +238,7 @@ continue; } SmallVector indices = getIndicesForAccess( - builder, linalgOpLoc, linalgOp.getTiedIndexingMap(operand), ivs); + builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(operand), ivs); Value load = builder.create(linalgOpLoc, operand->get(), indices); indexedValues.push_back(load); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -172,7 +172,7 @@ OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand, ArrayRef paddingDimensions, ArrayRef paddingValues, ArrayRef packPaddings) { - AffineMap indexingMap = opToPad.getTiedIndexingMap(opOperand); + AffineMap indexingMap = opToPad.getMatchingIndexingMap(opOperand); ArrayRef shape = opToPad.getShape(opOperand); // Collect the shape dimension that are a function of the `paddingDimensions`. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -215,7 +215,7 @@ if (vectorType.getRank() > 0) { // 0-d case is still special: do not invert the reindexing map. AffineMap map = - reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand)); + reindexIndexingMap(linalgOp.getMatchingIndexingMap(outputOperand)); SmallVector transposeShape = applyPermutationMap(inversePermutation(map), vectorType.getShape()); assert(!transposeShape.empty() && "unexpected empty transpose shape"); @@ -479,12 +479,12 @@ // } else { if (opOperand->getOperandNumber() < linalgOp.getNumInputs()) { map = inverseAndBroadcastProjectedPermutation( - linalgOp.getTiedIndexingMap(opOperand)); + linalgOp.getMatchingIndexingMap(opOperand)); readType = VectorType::get(commonVectorShape, getElementTypeOrSelf(opOperand->get())); } else { map = inversePermutation( - reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand))); + reindexIndexingMap(linalgOp.getMatchingIndexingMap(opOperand))); readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)), getElementTypeOrSelf(opOperand->get())); } @@ -545,7 +545,7 @@ return failure(); } for (OpOperand *opOperand : op.getOutputOperands()) { - AffineMap indexingMap = op.getTiedIndexingMap(opOperand); + AffineMap indexingMap = op.getMatchingIndexingMap(opOperand); if (indexingMap.isPermutation()) continue; diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -180,7 +180,7 @@ // TODO: relax the restrictions on indexing map. for (OpOperand *opOperand : op.getOutputOperands()) { - if (!op.getTiedIndexingMap(opOperand).isPermutation()) + if (!op.getMatchingIndexingMap(opOperand).isPermutation()) return false; } return hasOnlyScalarElementwiseOp(op->getRegion(0)); @@ -967,7 +967,7 @@ for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { Value shapedOp = valuesToTile[opOperand->getOperandNumber()]; LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp); - AffineMap map = linalgOp.getTiedIndexingMap(opOperand); + AffineMap map = linalgOp.getMatchingIndexingMap(opOperand); // Use `opOperand` as is if it is not tiled and not an output tensor. Having // an extract/insert slice pair for all output tensors simplifies follow up // transformations such as padding and bufferization since the diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -170,9 +170,9 @@ if (!op.hasTensorSemantics() || op.getNumInputs() != 2 || op.getNumResults() != 1 || op.getNumParallelLoops() != op.getNumLoops() || - !op.getTiedIndexingMap(op.getOutputOperand(0)).isIdentity() || - !op.getTiedIndexingMap(op.getInputOperand(0)).isIdentity() || - !op.getTiedIndexingMap(op.getInputOperand(1)).isIdentity()) + !op.getMatchingIndexingMap(op.getOutputOperand(0)).isIdentity() || + !op.getMatchingIndexingMap(op.getInputOperand(0)).isIdentity() || + !op.getMatchingIndexingMap(op.getInputOperand(1)).isIdentity()) return failure(); // Find consuming OP2(sparse, other) or OP2(other, sparse). The other // operand can be sparse or dense, since the point of this rewriting rule diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -195,7 +195,7 @@ static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) { bool annotated = false; for (OpOperand *t : op.getInputAndOutputOperands()) { - auto map = op.getTiedIndexingMap(t); + auto map = op.getMatchingIndexingMap(t); auto enc = getSparseTensorEncoding(t->get().getType()); if (enc) annotated = true; @@ -296,7 +296,7 @@ if (t == skip) continue; // Get map and encoding. - auto map = op.getTiedIndexingMap(t); + auto map = op.getMatchingIndexingMap(t); auto enc = getSparseTensorEncoding(t->get().getType()); assert(map.getNumDims() == n); // Skip dense tensor constraints when not requested. @@ -542,7 +542,7 @@ for (OpOperand *t : op.getInputAndOutputOperands()) { unsigned tensor = t->getOperandNumber(); auto shape = op.getShape(t); - auto map = op.getTiedIndexingMap(t); + auto map = op.getMatchingIndexingMap(t); auto enc = getSparseTensorEncoding(t->get().getType()); // Scan all dimensions of current tensor. args.clear(); @@ -721,7 +721,7 @@ /// Generates index for load/store on sparse tensor. static Value genIndex(CodeGen &codegen, linalg::GenericOp op, OpOperand *t) { - auto map = op.getTiedIndexingMap(t); + auto map = op.getMatchingIndexingMap(t); auto enc = getSparseTensorEncoding(t->get().getType()); AffineExpr a = map.getResult(toOrigDim(enc, map.getNumResults() - 1)); assert(a.getKind() == AffineExprKind::DimId); @@ -734,7 +734,7 @@ linalg::GenericOp op, OpOperand *t, SmallVector &args) { unsigned tensor = t->getOperandNumber(); - auto map = op.getTiedIndexingMap(t); + auto map = op.getMatchingIndexingMap(t); auto enc = getSparseTensorEncoding(t->get().getType()); unsigned rank = map.getNumResults(); if (enc) { @@ -1079,7 +1079,7 @@ // Inspect tensor indices. bool atLevel = ldx == -1u; OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; - auto map = op.getTiedIndexingMap(t); + auto map = op.getMatchingIndexingMap(t); auto enc = getSparseTensorEncoding(t->get().getType()); for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { AffineExpr a = map.getResult(toOrigDim(enc, d)); @@ -1275,7 +1275,7 @@ unsigned idx) { for (OpOperand *t : op.getInputAndOutputOperands()) { if (!getSparseTensorEncoding(t->get().getType())) { - auto map = op.getTiedIndexingMap(t); + auto map = op.getMatchingIndexingMap(t); for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { AffineExpr a = map.getResult(d); // Report non-unit stride if innermost index appears at an outer @@ -1920,7 +1920,8 @@ auto srcTp = tval.getType().cast(); auto dstEnc = SparseTensorEncodingAttr::get( op->getContext(), srcEnc.getDimLevelType(), - permute(getContext(), op.getTiedIndexingMap(t), topSort), // new order + permute(getContext(), op.getMatchingIndexingMap(t), + topSort), // new order srcEnc.getPointerBitWidth(), srcEnc.getIndexBitWidth()); auto dstTp = RankedTensorType::get(srcTp.getShape(), srcTp.getElementType(), dstEnc);