diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -89,6 +89,8 @@
 
 /// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`.
 /// Assumes `op` is a LinalgOp.
+void getDimsOfType(Operation *op, StringRef iteratorTypeName,
+                   SmallVectorImpl<unsigned> &res);
 void getDimsOfType(Operation *op, StringRef iteratorTypeName,
                    SmallVectorImpl<AffineExpr> &res);
 
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -158,7 +158,7 @@
       }],
       /*retTy=*/"void",
       /*methodName=*/"getParallelDims",
-      /*args=*/(ins "SmallVectorImpl<AffineExpr> &":$res),
+      /*args=*/(ins "SmallVectorImpl<unsigned> &":$res),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         return getDimsOfType($_op, getParallelIteratorTypeName(), res);
@@ -183,7 +183,7 @@
       }],
       /*retTy=*/"void",
       /*methodName=*/"getReductionDims",
-      /*args=*/(ins "SmallVectorImpl<AffineExpr> &":$res),
+      /*args=*/(ins "SmallVectorImpl<unsigned> &":$res),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         return getDimsOfType($_op, getReductionIteratorTypeName(), res);
@@ -208,7 +208,7 @@
       }],
       /*retTy=*/"void",
       /*methodName=*/"getWindowDims",
-      /*args=*/(ins "SmallVectorImpl<AffineExpr> &":$res),
+      /*args=*/(ins "SmallVectorImpl<unsigned> &":$res),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         return getDimsOfType($_op.getOperation(), getWindowIteratorTypeName(), res);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -612,7 +612,7 @@
              << indexingMap.getNumResults() << ")";
   }
 
-  SmallVector<AffineExpr, 4> redDims;
+  SmallVector<unsigned> redDims;
   linalgOp.getReductionDims(redDims);
 
   // Simplifying assumption: either full tensor or full buffer mode.
@@ -639,8 +639,7 @@
   for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
     AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);
     for (auto expr : indexingMap.getResults()) {
-      for (auto dim : redDims) {
-        unsigned pos = dim.cast<AffineDimExpr>().getPosition();
+      for (auto pos : redDims) {
         if (expr.isFunctionOfDim(pos)) {
           std::string exprStr;
           {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2318,19 +2318,32 @@
 /// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`.
 /// Assumes `op` is a LinalgOp.
 void mlir::linalg::getDimsOfType(Operation *op, StringRef iteratorTypeName,
-                                 SmallVectorImpl<AffineExpr> &res) {
+                                 SmallVectorImpl<unsigned> &res) {
   if (!cast<LinalgOp>(op).iterator_types())
     return;
 
   unsigned dim = 0;
-  MLIRContext *ctx = op->getContext();
   for (auto tn :
        cast<LinalgOp>(op).iterator_types().getAsValueRange<StringAttr>()) {
     if (tn == iteratorTypeName)
-      res.push_back(getAffineDimExpr(dim, ctx));
+      res.push_back(dim);
     ++dim;
   }
 }
+
+/// Overload that appends the matching dims to `res` as AffineDimExprs instead
+/// of raw positions. Assumes `op` is a LinalgOp.
+void mlir::linalg::getDimsOfType(Operation *op, StringRef iteratorTypeName,
+                                 SmallVectorImpl<AffineExpr> &res) {
+  SmallVector<unsigned> dimPositions;
+  getDimsOfType(op, iteratorTypeName, dimPositions);
+  // Append rather than resize-and-overwrite so that anything already in `res`
+  // survives, exactly as it did before this overload delegated.
+  MLIRContext *ctx = op->getContext();
+  res.reserve(res.size() + dimPositions.size());
+  for (unsigned pos : dimPositions)
+    res.push_back(getAffineDimExpr(pos, ctx));
+}
 
 AffineMap mlir::linalg::extractOrIdentityMap(Optional<AffineMap> maybeMap,
                                              unsigned rank,