diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -1087,6 +1087,44 @@
           getNumInputs() + resultIdx, dim);
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Like `getShape`, but only returns statically-known information, without
+        generating any new IR. For each shape dimension, returns >=0 if that
+        dimension is statically known, or ShapedType::kDynamicSize otherwise.
+      }],
+      /*retTy=*/"SmallVector<int64_t, 8>",
+      /*methodName=*/"getStaticShape",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        SmallVector<int64_t, 8> res;
+        for (Value v : getShapedOperands()) {
+          auto shape = v.getType().cast<ShapedType>().getShape();
+          res.append(shape.begin(), shape.end());
+        }
+        return res;
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns the statically-known loop ranges. Composes
+        `getShapesToLoopsMap()` with the result of `getStaticShape`.
+        Returns None if `getShapesToLoopsMap()` fails. Returns
+        ShapedType::kDynamicSize for non-statically-known loop ranges.
+      }],
+      /*retTy=*/"Optional<SmallVector<int64_t, 4>>",
+      /*methodName=*/"getStaticLoopRanges",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        SmallVector<int64_t, 8> viewSizes = getStaticShape();
+        AffineMap invertedMap = getShapesToLoopsMap();
+        if (!invertedMap)
+          return {};
+        return invertedMap.compose(viewSizes);
+      }]
+    >,
 
     //===------------------------------------------------------------------===//
     // Other static interface methods.
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -118,17 +118,6 @@
 Optional<SmallVector<Value, 1>> fuseTensorOps(PatternRewriter &rewriter,
                                               OpOperand &consumerOpOperand);
 
-/// Like `getShape`, but only returns statically-known information, without
-/// generating any new IR. For each shape dimension, returns >=0 if that
-/// dimension is statically known, or -1 otherwise.
-SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp);
-
-/// Returns the statically-known loop ranges of the `linalgOp`. Composes
-/// `linalgOp.getShapesToLoopsMap()` with the result of `getStaticShape`.
-/// Returns None if `linalgOp.getShapesToLoopsMap()` fails. Returns -1
-/// for non-statically-known loop ranges.
-Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp);
-
 /// Apply the permutation defined by `permutation` to `inVec`.
 /// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
 /// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -499,7 +499,7 @@
   AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex);
 
   Optional<SmallVector<int64_t, 4>> originalLoopRange =
-      getStaticLoopRanges(linalgOp);
+      linalgOp.getStaticLoopRanges();
   if (!originalLoopRange)
     return linalgOp.emitError("unable to find loop range for operation");
 
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -98,23 +98,6 @@
 namespace mlir {
 namespace linalg {
 
-SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp) {
-  SmallVector<int64_t, 8> res;
-  for (Value v : linalgOp.getShapedOperands()) {
-    auto shape = v.getType().cast<ShapedType>().getShape();
-    res.append(shape.begin(), shape.end());
-  }
-  return res;
-}
-
-Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp) {
-  SmallVector<int64_t, 8> viewSizes = getStaticShape(linalgOp);
-  AffineMap invertedMap = linalgOp.getShapesToLoopsMap();
-  if (!invertedMap)
-    return {};
-  return invertedMap.compose(viewSizes);
-}
-
 /// If `size` comes from an AffineMinOp and one of the values of AffineMinOp
 /// is a constant then return a new value set to the smallest such constant.
 /// Otherwise returngetSmallestBoundingIndex nullptr.