diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -119,28 +119,26 @@ LinalgOp linalgOp) { assert(linalgOp.hasBufferSemantics() && "expected linalg op with buffer semantics"); - unsigned nInputs = linalgOp.getNumInputs(); - unsigned nOutputs = linalgOp.getNumOutputs(); SmallVector indexedValues; - indexedValues.reserve(nInputs + nOutputs); + indexedValues.reserve(linalgOp.getNumInputsAndOutputs()); auto allIvsPlusDims = SmallVector(allIvs.begin(), allIvs.end()); // TODO: Avoid the loads if the corresponding argument of the // region has no uses. // 1.a. Emit load from input views. - for (unsigned i = 0; i < nInputs; ++i) { + for (OpOperand *inputOperand : linalgOp.getInputOperands()) { auto indexing = makeCanonicalAffineApplies( - b, loc, linalgOp.getInputIndexingMap(i), allIvsPlusDims); + b, loc, linalgOp.getTiedIndexingMap(inputOperand), allIvsPlusDims); indexedValues.push_back( - b.create(loc, linalgOp.getInput(i), indexing)); + b.create(loc, inputOperand->get(), indexing)); } // 1.b. Emit load from output views. - for (unsigned i = 0; i < nOutputs; ++i) { - auto indexing = makeCanonicalAffineApplies( - b, loc, linalgOp.getOutputIndexingMap(i), allIvsPlusDims); + for (OpOperand *outputOperand : linalgOp.getOutputOperands()) { + SmallVector indexing = makeCanonicalAffineApplies( + b, loc, linalgOp.getTiedIndexingMap(outputOperand), allIvsPlusDims); indexedValues.push_back( - b.create(loc, linalgOp.getOutputBuffer(i), indexing)); + b.create(loc, outputOperand->get(), indexing)); } // TODO: When a region inliner exists, use it. @@ -148,10 +146,10 @@ // 3. Emit store. SmallVector, 8> indexing; SmallVector outputBuffers; - for (unsigned i = 0; i < nOutputs; ++i) { + for (OpOperand *outputOperand : linalgOp.getOutputBufferOperands()) { indexing.push_back(makeCanonicalAffineApplies( - b, loc, linalgOp.getOutputIndexingMap(i), allIvsPlusDims)); - outputBuffers.push_back(linalgOp.getOutputBuffer(i)); + b, loc, linalgOp.getTiedIndexingMap(outputOperand), allIvsPlusDims)); + outputBuffers.push_back(outputOperand->get()); } inlineRegionAndEmitStore(b, loc, linalgOp, indexedValues, indexing, outputBuffers);