diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -375,11 +375,12 @@ void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {} static LogicalResult verify(CopyOp op) { - auto outputViewType = op.getOutputShapedType(0); - auto inputViewType = op.getInputShapedType(0); - if (inputViewType.getElementType() != outputViewType.getElementType()) + OpOperand *output = op.getOutputOperand(0); + OpOperand *input = op.getInputOperand(0); + if (getElementTypeOrSelf(input->get().getType()) != + getElementTypeOrSelf(output->get().getType())) return op.emitOpError("expects views of the same type"); - if (inputViewType.getRank() != outputViewType.getRank()) + if (op.getRank(input) != op.getRank(output)) return op.emitOpError("expects views of the same rank"); auto rank = op.getNumParallelLoops(); auto inputPermutationMap = op.inputPermutation(); @@ -449,11 +450,11 @@ void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Value) {} static LogicalResult verify(FillOp op) { - auto viewType = op.getOutputShapedType(0); - auto fillType = op.value().getType(); - if (viewType.getElementType() != fillType) + OpOperand *output = op.getOutputOperand(0); + Type fillType = op.value().getType(); + if (getElementTypeOrSelf(output->get().getType()) != fillType) return op.emitOpError("expects fill type to match view elemental type"); - if (!op.getNumResults() && !viewType.isa()) { + if (!op.getNumResults() && !output->get().getType().isa()) { return op.emitOpError( "expected fill op with no result value to use memref type"); } @@ -739,11 +740,13 @@ // Create a generic replacement operation and clone the body. rewriter.setInsertionPointAfter(indexedOp); + SmallVector inputOperands = indexedOp.getInputOperands(); + SmallVector outputOperands = indexedOp.getOutputOperands(); SmallVector iterators = llvm::to_vector<4>( indexedOp.iterator_types().getAsValueRange()); GenericOp genericOp = rewriter.create( - indexedOp.getLoc(), indexedOp->getResultTypes(), indexedOp.getInputs(), - indexedOp.getOutputs(), indexedOp.getIndexingMaps(), iterators); + indexedOp.getLoc(), indexedOp->getResultTypes(), inputOperands, + outputOperands, indexedOp.getIndexingMaps(), iterators); Region &genericRegion = genericOp.region(); Region &indexedRegion = indexedOp.region(); rewriter.cloneRegionBefore(indexedRegion, genericRegion, @@ -2107,21 +2110,21 @@ // Check the operand number and types must match the element types of the // LinalgOp interface's shaped operands. -static LogicalResult verifyYield(linalg::YieldOp op, - LinalgOp linalgOpInterface) { - auto nOutputs = linalgOpInterface.getNumOutputs(); - if (op.getNumOperands() != nOutputs) +static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) { + if (op.getNumOperands() != linalgOp.getNumOutputs()) return op.emitOpError("expected number of yield values (") - << nOutputs << ") to match the number of operands of the enclosing " + << linalgOp.getNumOutputs() + << ") to match the number of operands of the enclosing " << "LinalgOp (" << op.getNumOperands() << ")"; - for (unsigned i = 0; i != nOutputs; ++i) { - auto elementType = - linalgOpInterface.getOutputShapedType(i).getElementType(); - if (op.getOperand(i).getType() != elementType) + for (OpOperand &opOperand : op->getOpOperands()) { + OpOperand *outputOperand = + linalgOp.getOutputOperand(opOperand.getOperandNumber()); + Type elementType = getElementTypeOrSelf(outputOperand->get().getType()); + if (opOperand.get().getType() != elementType) return op.emitOpError("type of yield operand ") - << (i + 1) << " (" << op.getOperand(i).getType() - << ") doesn't match " + << (opOperand.getOperandNumber() + 1) << " (" + << opOperand.get().getType() << ") doesn't match " << "the element type of the enclosing linalg.generic op (" << elementType << ")"; } @@ -3096,14 +3099,14 @@ LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override { - for (Value v : op.getShapedOperands()) { + for (OpOperand *opOperand : op.getInputAndOutputOperands()) { // Linalg "inputs" may be either tensor or memref type. // tensor<0xelt_type> is a convention that may not always mean // "0 iterations". Only erase in cases we see memref<...x0x...>. - auto mt = v.getType().dyn_cast(); + auto mt = opOperand->get().getType().dyn_cast(); if (!mt) continue; - if (llvm::is_contained(mt.getShape(), 0)) { + if (llvm::is_contained(op.getShape(opOperand), 0)) { rewriter.eraseOp(op); return success(); } @@ -3119,10 +3122,10 @@ PatternRewriter &rewriter) const override { // If no operand comes from a tensor::CastOp and can be folded then fail. bool hasTensorCastOperand = - llvm::any_of(op.getShapedOperands(), [&](Value v) { - if (v.isa()) + llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) { + if (opOperand->get().isa()) return false; - auto castOp = v.getDefiningOp(); + auto castOp = opOperand->get().getDefiningOp(); return castOp && canFoldIntoConsumerOp(castOp); }); if (!hasTensorCastOperand) @@ -3133,16 +3136,18 @@ SmallVector newOperands; newOperands.reserve(op->getNumOperands()); // Inputs may fold. - for (Value v : op.getInputs()) { - auto tensorCastOp = v.getDefiningOp(); - newOperands.push_back( - canFoldIntoConsumerOp(tensorCastOp) ? tensorCastOp.source() : v); + for (OpOperand *opOperand : op.getInputOperands()) { + auto tensorCastOp = opOperand->get().getDefiningOp(); + newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp) + ? tensorCastOp.source() + : opOperand->get()); } // Init tensors may fold, in which case the resultType must also change. - for (Value v : op.getOutputs()) { - auto tensorCastOp = v.getDefiningOp(); + for (OpOperand *opOperand : op.getOutputOperands()) { + auto tensorCastOp = opOperand->get().getDefiningOp(); bool fold = canFoldIntoConsumerOp(tensorCastOp); - newOperands.push_back(fold ? tensorCastOp.getOperand() : v); + newOperands.push_back(fold ? tensorCastOp.getOperand() + : opOperand->get()); newResultTypes.push_back(newOperands.back().getType()); } auto extraOperands = op.getAssumedNonShapedOperands(); @@ -3189,18 +3194,18 @@ // in the case of duplicated inputs, the canonical input could be some other // input `< i`. That is, a later input will have some earlier input as its // canonical input. - llvm::SmallDenseMap, int> canonicalInput; + llvm::SmallDenseMap, unsigned> canonicalInput; // For later remapping tasks like deduplicating payload block arguments, // having a simple "inputIndex -> canonicalInputIndex" integer mapping is // convenient. - SmallVector canonicalInputIndices; - for (int i = 0, e = op.getNumInputs(); i != e; i++) { - Value input = op.getInput(i); - AffineMap indexingMap = op.getInputIndexingMap(i); + SmallVector canonicalInputIndices; + for (OpOperand *opOperand : op.getInputOperands()) { + AffineMap indexingMap = op.getTiedIndexingMap(opOperand); // STL-like maps have a convenient behavior for our use case here. In the // case of duplicate keys, the insertion is rejected, and the returned // iterator gives access to the value already in the map. - auto pair = canonicalInput.insert({{input, indexingMap}, i}); + auto pair = canonicalInput.insert( + {{opOperand->get(), indexingMap}, opOperand->getOperandNumber()}); canonicalInputIndices.push_back(pair.first->second); } @@ -3209,26 +3214,29 @@ return failure(); // The operands for the newly canonicalized op. - SmallVector newOperands; - for (auto v : llvm::enumerate(op.getInputs())) - if (canonicalInputIndices[v.index()] == static_cast(v.index())) - newOperands.push_back(v.value()); - llvm::append_range(newOperands, op.getOutputs()); + SmallVector newOperands; + for (OpOperand *opOperand : op.getInputOperands()) + if (canonicalInputIndices[opOperand->getOperandNumber()] == + opOperand->getOperandNumber()) + newOperands.push_back(opOperand->get()); + SmallVector outputOperands = op.getOutputOperands(); + llvm::append_range(newOperands, outputOperands); llvm::append_range(newOperands, op.getAssumedNonShapedOperands()); + // Repair the indexing maps by filtering out the ones that have been + // eliminated. + SmallVector newIndexingMaps; + for (OpOperand *opOperand : op.getInputOperands()) + if (canonicalInputIndices[opOperand->getOperandNumber()] == + opOperand->getOperandNumber()) + newIndexingMaps.push_back(op.getTiedIndexingMap(opOperand)); + for (OpOperand *opOperand : op.getOutputOperands()) + newIndexingMaps.push_back(op.getTiedIndexingMap(opOperand)); + // Clone the old op with new operands. Operation *newOp = op.clone(rewriter, op->getLoc(), op->getResultTypes(), newOperands); auto newLinalgOp = cast(newOp); - - // Repair the indexing maps by filtering out the ones that have been - // eliminated. - SmallVector newIndexingMaps; - for (int i = 0, e = newLinalgOp.getNumInputs(); i != e; i++) - if (canonicalInputIndices[i] == i) - newIndexingMaps.push_back(newLinalgOp.getIndexingMap(i)); - for (int i = 0, e = newLinalgOp.getNumOutputs(); i != e; i++) - newIndexingMaps.push_back(newLinalgOp.getOutputIndexingMap(i)); newOp->setAttr("indexing_maps", rewriter.getAffineMapArrayAttr(newIndexingMaps)); @@ -3243,18 +3251,18 @@ // Repair the payload entry block by RAUW'ing redundant arguments and // erasing them. Block &payload = newOp->getRegion(0).front(); - for (int i = 0, e = op.getNumInputs(); i < e; i++) { + SmallVector inputOperands = op.getInputOperands(); + for (OpOperand *opOperand : llvm::reverse(inputOperands)) { // Iterate in reverse, so that we erase later args first, preventing the // argument list from shifting unexpectedly and invalidating all our // indices. - int reversed = e - i - 1; - int canonicalIndex = canonicalInputIndices[reversed]; - if (canonicalInputIndices[reversed] == reversed) + unsigned operandNumber = opOperand->getOperandNumber(); + if (canonicalInputIndices[operandNumber] == operandNumber) continue; - payload.getArgument(bbArgBaseOffset + reversed) - .replaceAllUsesWith( - payload.getArgument(bbArgBaseOffset + canonicalIndex)); - payload.eraseArgument(bbArgBaseOffset + reversed); + payload.getArgument(bbArgBaseOffset + operandNumber) + .replaceAllUsesWith(payload.getArgument( + bbArgBaseOffset + canonicalInputIndices[operandNumber])); + payload.eraseArgument(bbArgBaseOffset + operandNumber); } rewriter.replaceOp(op, newOp->getResults());