diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -1087,6 +1087,44 @@
         getNumInputs() + resultIdx, dim);
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Like `getShape`, but only returns statically-known information, without
+        generating any new IR. For each shape dimension, returns >=0 if that
+        dimension is statically known, or ShapedType::kDynamicSize otherwise.
+      }],
+      /*retTy=*/"SmallVector<int64_t, 4>",
+      /*methodName=*/"getStaticShape",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        SmallVector<int64_t, 4> res;
+        for (Value v : getShapedOperands()) {
+          auto shape = v.getType().cast<ShapedType>().getShape();
+          res.append(shape.begin(), shape.end());
+        }
+        return res;
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns the statically-known loop ranges. Composes
+        `getShapesToLoopsMap()` with the result of `getStaticShape`.
+        Returns None if `getShapesToLoopsMap()` fails. Returns
+        ShapedType::kDynamicSize for non-statically-known loop ranges.
+      }],
+      /*retTy=*/"Optional<SmallVector<int64_t, 4>>",
+      /*methodName=*/"getStaticLoopRanges",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        SmallVector<int64_t, 4> viewSizes = getStaticShape();
+        AffineMap invertedMap = getShapesToLoopsMap();
+        if (!invertedMap)
+          return {};
+        return invertedMap.compose(viewSizes);
+      }]
+    >,

     //===------------------------------------------------------------------===//
     // Other static interface methods.
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -118,17 +118,6 @@
 Optional<SmallVector<Value, 1>> fuseTensorOps(PatternRewriter &rewriter,
                                               OpOperand &consumerOpOperand);
 
-/// Like `getShape`, but only returns statically-known information, without
-/// generating any new IR. For each shape dimension, returns >=0 if that
-/// dimension is statically known, or -1 otherwise.
-SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp);
-
-/// Returns the statically-known loop ranges of the `linalgOp`. Composes
-/// `linalgOp.getShapesToLoopsMap()` with the result of `getStaticShape`.
-/// Returns None if `linalgOp.getShapesToLoopsMap()` fails. Returns -1
-/// for non-statically-known loop ranges.
-Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp);
-
 /// Apply the permutation defined by `permutation` to `inVec`.
 /// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
 /// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -499,7 +499,7 @@
   AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex);
 
   Optional<SmallVector<int64_t, 4>> originalLoopRange =
-      getStaticLoopRanges(linalgOp);
+      linalgOp.getStaticLoopRanges();
   if (!originalLoopRange)
     return linalgOp.emitError("unable to find loop range for operation");
 
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -98,23 +98,6 @@
 namespace mlir {
 namespace linalg {
 
-SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp) {
-  SmallVector<int64_t, 8> res;
-  for (Value v : linalgOp.getShapedOperands()) {
-    auto shape = v.getType().cast<ShapedType>().getShape();
-    res.append(shape.begin(), shape.end());
-  }
-  return res;
-}
-
-Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp) {
-  SmallVector<int64_t, 8> viewSizes = getStaticShape(linalgOp);
-  AffineMap invertedMap = linalgOp.getShapesToLoopsMap();
-  if (!invertedMap)
-    return {};
-  return invertedMap.compose(viewSizes);
-}
-
 /// If `size` comes from an AffineMinOp and one of the values of AffineMinOp
 /// is a constant then return a new value set to the smallest such constant.
 /// Otherwise return nullptr.
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -259,80 +259,3 @@
 // CHECK-NEXT: %[[MUL:.+]] = mulf %[[BBARG0]], %[[BBARG1]] : f32
 // CHECK-NEXT: %[[ADD:.+]] = addf %[[BBARG2]], %[[MUL]] : f32
 // CHECK-NEXT: linalg.yield %[[ADD]] : f32
-
-// -----
-
-func @pooling_nhwc_sum(%input: memref<?x?x?x?xf32>, %fake: memref<2x3xf32>, %init: memref<?x?x?x?xf32>) {
-  linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
-    ins(%input, %fake: memref<?x?x?x?xf32>, memref<2x3xf32>)
-    outs(%init: memref<?x?x?x?xf32>)
-  return
-}
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-
-// CHECK: func @pooling_nhwc_sum
-
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
-// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?x?xf32>, memref<2x3xf32>)
-// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?xf32>)
-
-// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
-// CHECK-NEXT: %[[RES:.+]] = addf %[[BBARG2]], %[[BBARG0]] : f32
-// CHECK-NEXT: linalg.yield %[[RES]] : f32
-
-// -----
-
-func @pooling_nhwc_max(%input: memref<?x?x?x?xf32>, %fake: memref<2x3xf32>, %init: memref<?x?x?x?xf32>) {
-  linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>}
-    ins(%input, %fake: memref<?x?x?x?xf32>, memref<2x3xf32>)
-    outs(%init: memref<?x?x?x?xf32>)
-  return
-}
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 3 + d5, d3)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-
-// CHECK: func @pooling_nhwc_max
-
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
-// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?x?xf32>, memref<2x3xf32>)
-// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?xf32>)
-
-// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
-// CHECK-NEXT: %[[CMP:.+]] = cmpf ogt, %[[BBARG0]], %[[BBARG2]] : f32
-// CHECK-NEXT: %[[RES:.+]] = select %[[CMP]], %[[BBARG0]], %[[BBARG2]] : f32
-// CHECK-NEXT: linalg.yield %[[RES]] : f32
-
-// -----
-
-func @pooling_nhwc_min(%input: memref<?x?x?x?xf32>, %fake: memref<2x3xf32>, %init: memref<?x?x?x?xf32>) {
-  linalg.pooling_nhwc_min {dilations = dense<3> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
-    ins(%input, %fake: memref<?x?x?x?xf32>, memref<2x3xf32>)
-    outs(%init: memref<?x?x?x?xf32>)
-  return
-}
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4 * 3, d2 * 2 + d5 * 3, d3)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-
-// CHECK: func @pooling_nhwc_min
-
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
-// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?x?xf32>, memref<2x3xf32>)
-// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?xf32>)
-
-// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
-// CHECK-NEXT: %[[CMP:.+]] = cmpf olt, %[[BBARG0]], %[[BBARG2]] : f32
-// CHECK-NEXT: %[[RES:.+]] = select %[[CMP]], %[[BBARG0]], %[[BBARG2]] : f32
-// CHECK-NEXT: linalg.yield %[[RES]] : f32
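
Note: a minimal usage sketch (not part of this patch) of the new interface method, in the style of the FusionOnTensors.cpp call site updated above. The helper name `allLoopRangesStatic` is hypothetical, and header paths and the exact SmallVector sizes are assumptions that may vary across MLIR revisions.

#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"

using namespace mlir;

// Returns true if every loop of `linalgOp` has a statically-known trip count.
static bool allLoopRangesStatic(linalg::LinalgOp linalgOp) {
  // None means getShapesToLoopsMap() could not invert the shapes-to-loops map.
  Optional<SmallVector<int64_t, 4>> loopRanges =
      linalgOp.getStaticLoopRanges();
  if (!loopRanges)
    return false;
  // Each entry is either a static extent (>= 0) or ShapedType::kDynamicSize.
  return llvm::none_of(*loopRanges, [](int64_t extent) {
    return extent == ShapedType::kDynamicSize;
  });
}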