diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -528,10 +528,10 @@ // 3. Turn all BBArgs into vector.transfer_read / load. Location loc = linalgOp.getLoc(); Value zero = b.create(loc, 0); - for (OpOperand &opOperand : linalgOp->getOpOperands()) { - BlockArgument bbarg = block->getArgument(opOperand.getOperandNumber()); - if (linalgOp.isScalar(&opOperand)) { - bvm.map(bbarg, opOperand.get()); + for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) { + BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand); + if (linalgOp.isScalar(opOperand)) { + bvm.map(bbarg, opOperand->get()); continue; } VectorType readType; @@ -540,23 +540,23 @@ // if (linalgOp.getShape(&opOperand).empty()) { // readType = VectorType::get({}, bbarg.getType()); // } else { - if (opOperand.getOperandNumber() < linalgOp.getNumInputs()) { + if (opOperand->getOperandNumber() < linalgOp.getNumInputs()) { map = inverseAndBroadcastProjectedPermutation( - linalgOp.getMatchingIndexingMap(&opOperand)); + linalgOp.getMatchingIndexingMap(opOperand)); readType = VectorType::get(commonVectorShape, - getElementTypeOrSelf(opOperand.get())); + getElementTypeOrSelf(opOperand->get())); } else { map = inversePermutation( - reindexIndexingMap(linalgOp.getMatchingIndexingMap(&opOperand))); - readType = VectorType::get(map.compose(linalgOp.getShape(&opOperand)), - getElementTypeOrSelf(opOperand.get())); + reindexIndexingMap(linalgOp.getMatchingIndexingMap(opOperand))); + readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)), + getElementTypeOrSelf(opOperand->get())); } // } - auto shape = linalgOp.getShape(&opOperand); + auto shape = linalgOp.getShape(opOperand); SmallVector indices(shape.size(), zero); Value readValue = b.create( - loc, readType, opOperand.get(), indices, map); + loc, readType, opOperand->get(), indices, map); // Not all ops support 0-d vectors, extract the scalar for now. // TODO: remove this. if (readValue.getType().cast().getRank() == 0) @@ -564,7 +564,7 @@ LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue); bvm.map(bbarg, readValue); - bvm.map(opOperand.get(), readValue); + bvm.map(opOperand->get(), readValue); } SmallVector hooks;