Add static-shape helpers to the Linalg utilities.

getStaticShape concatenates the static shapes of all shaped operands of a
LinalgOp without creating IR; getStaticLoopRanges maps that list through the
inverse of the concatenated indexing maps. Unknown extents are reported as -1.

NOTE(review): template argument lists (e.g. SmallVector inline sizes,
cast<...> targets) were restored after being lost in extraction -- confirm
them against the upstream tree before applying.

diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -114,12 +114,23 @@
   return getShape(builder, cast<linalg::LinalgOp>(linalgOp.getOperation()));
 }
 
+/// Like `getShape`, but only returns statically-known information, without
+/// generating any new IR. For each shape dimension, returns >=0 if that
+/// dimension is statically known, or -1 otherwise.
+SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp);
+
 /// Returns the loop ranges of the `linalgOp`. Applies the inverse of the
 /// concatenated indexing maps to the result of `getShape`. Returns None if
 /// the bounds computation fails.
 Optional<SmallVector<Value, 4>> getLoopRanges(OpBuilder &builder,
                                               LinalgOp linalgOp);
 
+/// Returns the statically-known loop ranges of the `linalgOp`. Applies the
+/// inverse of the concatenated indexing maps to the result of `getStaticShape`.
+/// Returns None if inverting the concatenated indexing map fails. Returns -1
+/// for non-statically-known loop ranges.
+Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp);
+
 /// Returns the values obtained by applying `map` to the list of values.
 SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
                                        AffineMap map, ValueRange values);
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -156,6 +156,15 @@
   return res;
 }
 
+SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp) {
+  SmallVector<int64_t, 8> res;
+  for (Value v : linalgOp.getShapedOperands()) {
+    auto shape = v.getType().template cast<ShapedType>().getShape();
+    res.append(shape.begin(), shape.end());
+  }
+  return res;
+}
+
 Optional<SmallVector<Value, 4>> getLoopRanges(OpBuilder &builder,
                                               LinalgOp linalgOp) {
   SmallVector<Value, 8> viewSizes = getShape(builder, linalgOp);
@@ -166,6 +175,15 @@
   return applyMapToValues(builder, linalgOp.getLoc(), invertedMap, viewSizes);
 }
 
+Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp) {
+  SmallVector<int64_t, 8> viewSizes = getStaticShape(linalgOp);
+  AffineMap invertedMap =
+      inversePermutation(concatAffineMaps(linalgOp.getIndexingMaps()));
+  if (!invertedMap)
+    return {};
+  return invertedMap.compose(viewSizes);
+}
+
 /// Specialization to build an scf "for" nest.
 template <>
 void GenerateLoopNest<scf::ForOp>::doit(