diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -180,9 +180,19 @@ if (failed(paddingValue)) return failure(hasDynamicShape); + auto operand = opOperand; + while (auto op = operand->get().getDefiningOp()) { + for (auto en : llvm::enumerate(op->getResults())) { + if (en.value() == operand->get()) { + operand = op.getOutputOperand(en.index()); + break; + } + } + } + // Cannot construct a static bounding box if the operand is not defined by an // ExtractSliceOp. - auto sliceOp = opOperand->get().getDefiningOp(); + auto sliceOp = operand->get().getDefiningOp(); if (!sliceOp) return failure(hasDynamicShape);