diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -67,6 +67,15 @@
     for (unsigned i = 0, e = t.getRank(); i < e; ++i)
       res.push_back(b.create<DimOp>(loc, v, i));
   }
+  if (getNumInitTensors() == 0 && getOperation()->getNumResults() != 0) {
+    OpBuilder::InsertionGuard guard(b);
+    b.setInsertionPointAfter(getOperation());
+    for (Value v : getOperation()->getResults()) {
+      ShapedType t = v.getType().template cast<ShapedType>();
+      for (unsigned i = 0, e = t.getRank(); i < e; ++i)
+        res.push_back(b.create<DimOp>(loc, v, i));
+    }
+  }
   return res;
 }
 
@@ -165,14 +174,6 @@
   HasAffineDimExprVisitor checkDimExpr(outputDims);
   if (checkDimExpr.visit(operandShapesToResultDimMap.getResult(0)))
     return llvm::None;
-
-  // Since map.compose(createFlatListOfOperands()) only works for the case
-  // where init tensors exist, drop the dims corresponding to output shapes in
-  // the map.
-  if (getNumInitTensors() == 0 && getOperation()->getNumResults() != 0) {
-    operandShapesToResultDimMap = getProjectedMap(
-        operandShapesToResultDimMap, llvm::to_vector<4>(outputDims));
-  }
   return applyMapToValues(b, loc, operandShapesToResultDimMap,
                           createFlatListOfOperandDims(b, loc))[0];
 }
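
For context: the new hunk appends `dim` ops for the op's results to the flat operand-dim list, using the `OpBuilder::InsertionGuard` + `setInsertionPointAfter` pattern so those dim ops are created at a point where the results are already defined. With the result dims in the list, `operandShapesToResultDimMap` composes directly over all shaped values, which is why the second hunk can drop the `getProjectedMap` workaround. Below is a minimal standalone sketch of that builder pattern, assuming the std-dialect `DimOp` of this code's vintage; `emitResultDims` is a hypothetical helper for illustration, not part of the patch.

```cpp
#include "mlir/Dialect/StandardOps/IR/Ops.h" // DimOp (std dialect, pre-split)
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Hypothetical helper mirroring the added hunk: emit one DimOp per result
// dimension, inserted *after* `op` so each dim op is dominated by the result
// value it queries.
static void emitResultDims(OpBuilder &b, Location loc, Operation *op,
                           SmallVectorImpl<Value> &res) {
  // The guard saves the current insertion point and restores it when this
  // scope exits, so the caller's insertion point is left untouched.
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPointAfter(op);
  for (Value v : op->getResults()) {
    auto t = v.getType().cast<ShapedType>();
    for (unsigned i = 0, e = t.getRank(); i < e; ++i)
      res.push_back(b.create<DimOp>(loc, v, i));
  }
}
```

The guard-based scoping is the key design choice here: the helper can reposition the builder without the caller having to remember and restore the insertion point manually.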