diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -145,18 +145,20 @@ : subViews(), dynamicBuffers(options.dynamicBuffers), alignment(options.alignment) { assert(linalgOp.hasBufferSemantics() && "revisit usage of shaped operand"); - int64_t nBuffers = linalgOp.getNumShapedOperands(); auto vUseFullTileBuffers = options.useFullTileBuffers.getValueOr(llvm::SmallBitVector()); - vUseFullTileBuffers.resize(nBuffers, options.useFullTileBuffersDefault); + vUseFullTileBuffers.resize(linalgOp.getNumInputsAndOutputs(), + options.useFullTileBuffersDefault); - for (int64_t idx = 0; idx != nBuffers; ++idx) { - if (options.operandsToPromote && !options.operandsToPromote->count(idx)) + for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { + int64_t operandNumber = opOperand->getOperandNumber(); + if (options.operandsToPromote && + !options.operandsToPromote->count(operandNumber)) continue; - auto *op = linalgOp.getShapedOperand(idx).getDefiningOp(); + Operation *op = opOperand->get().getDefiningOp(); if (auto sv = dyn_cast_or_null(op)) { - subViews[idx] = sv; - useFullTileBuffers[sv] = vUseFullTileBuffers[idx]; + subViews[operandNumber] = sv; + useFullTileBuffers[sv] = vUseFullTileBuffers[operandNumber]; } } @@ -318,23 +320,24 @@ // operands are not views. This is to support cases such as FillOp taking // extra scalars etc. Keep a reference to output buffers; SmallVector opViews; - opViews.reserve(op.getNumShapedOperands()); + opViews.reserve(op.getNumInputsAndOutputs()); SmallVector, 8> writebackViews; writebackViews.reserve(promotedBuffersAndViews->size()); - for (auto view : llvm::enumerate(op.getShapedOperands())) { - if (options.subViews.count(view.index()) != 0) { - if (options.useFullTileBuffers[view.value()]) + for (OpOperand *opOperand : op.getInputAndOutputOperands()) { + int64_t operandNumber = opOperand->getOperandNumber(); + if (options.subViews.count(operandNumber) != 0) { + if (options.useFullTileBuffers[opOperand->get()]) opViews.push_back( - (*promotedBuffersAndViews)[view.index()].fullLocalView); + (*promotedBuffersAndViews)[operandNumber].fullLocalView); else opViews.push_back( - (*promotedBuffersAndViews)[view.index()].partialLocalView); - if (static_cast(view.index()) >= op.getNumInputs()) + (*promotedBuffersAndViews)[operandNumber].partialLocalView); + if (operandNumber >= op.getNumInputs()) writebackViews.emplace_back(std::make_pair( - view.value(), - (*promotedBuffersAndViews)[view.index()].partialLocalView)); + opOperand->get(), + (*promotedBuffersAndViews)[operandNumber].partialLocalView)); } else { - opViews.push_back(view.value()); + opViews.push_back(opOperand->get()); } } op->setOperands(0, opViews.size(), opViews); @@ -357,16 +360,17 @@ LogicalResult mlir::linalg::promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options) { - LinalgOp linOp = dyn_cast(op); + LinalgOp linalgOp = dyn_cast(op); // Transformation applies to buffers only. - if (!linOp || !linOp.hasBufferSemantics()) + if (!linalgOp || !linalgOp.hasBufferSemantics()) return failure(); // Check that at least one of the requested operands is indeed a subview. - for (auto en : llvm::enumerate(linOp.getShapedOperands())) { - auto sv = isa_and_nonnull(en.value().getDefiningOp()); + for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { + auto sv = + isa_and_nonnull(opOperand->get().getDefiningOp()); if (sv) { if (!options.operandsToPromote.hasValue() || - options.operandsToPromote->count(en.index())) + options.operandsToPromote->count(opOperand->getOperandNumber())) return success(); } }