diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h @@ -16,6 +16,15 @@ #include "mlir/Pass/Pass.h" namespace mlir { + +class AffineDialect; +namespace tensor { +class TensorDialect; +} // namespace tensor +namespace vector { +class VectorDialect; +} // namespace vector + namespace memref { //===----------------------------------------------------------------------===// @@ -26,6 +35,11 @@ /// into `patterns`. void populateFoldSubViewOpPatterns(RewritePatternSet &patterns); +/// Appends patterns that resolve `memref.dim` operations on values that are +/// defined by operations that implement the `InferShapedTypeOpInterface`, in +/// terms of the shapes of the operands of those operations. +void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns); + //===----------------------------------------------------------------------===// // Passes //===----------------------------------------------------------------------===// @@ -34,6 +48,11 @@ /// load/store ops into `patterns`. std::unique_ptr createFoldSubViewOpsPass(); +/// Creates an operation pass to resolve `memref.dim` operations on values +/// that are defined by operations that implement the +/// `InferShapedTypeOpInterface`, in terms of the shapes of their operands. 
+std::unique_ptr createResolveShapedTypeResultDimsPass(); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td @@ -23,6 +23,18 @@ ]; } +def ResolveShapedTypeResultDims : Pass<"resolve-shaped-type-result-dims"> { + let summary = "Resolve memref.dim of result values"; + let description = [{ + The pass resolves memref.dim of result of operations that + implement the `InferShapedTypeOpInterface` in terms of shapes of + its operands. + }]; + let constructor = "mlir::memref::createResolveShapedTypeResultDimsPass()"; + let dependentDialects = [ + "memref::MemRefDialect", "tensor::TensorDialect" + ]; +} #endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -794,84 +794,12 @@ return success(); } }; - -/// Helper method to get the `Value` that is the shape of the `resultIdx`-th -/// result at dimension `dimIndex` from the `ShapedTypeOpInterface`. -/// TODO(ravishankarm): This is better put as a interface utility method -/// somewhere, but that would imply the interface will depend on the `tensor` -/// dialect. Ideally maybe a utility method in the `tensor` dialect. 
-static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result, - int64_t dimIndex) { - unsigned resultNumber = result.getResultNumber(); - auto shapedTypeOp = dyn_cast(result.getOwner()); - Location loc = result.getOwner()->getLoc(); - if (!shapedTypeOp) - return nullptr; - - // The interface exposes two methods, one that returns the shape of all the - // results as `Value` and other that returns the shape as a list of - // `SmallVector`. The former takes precedence over the latter. So first - // check if the op implements the first interface method or the second, and - // get the value to use appropriately. - SmallVector reifiedResultShapes; - if (succeeded(shapedTypeOp.reifyReturnTypeShapes( - builder, result.getOwner()->getOperands(), reifiedResultShapes))) { - if (reifiedResultShapes.size() <= resultNumber) - return nullptr; - Value resultShape = reifiedResultShapes[resultNumber]; - auto resultShapeType = resultShape.getType().dyn_cast(); - if (!resultShapeType || !resultShapeType.getElementType().isa()) - return nullptr; - return builder.create( - loc, resultShape, builder.createOrFold(loc, dimIndex)); - } - - SmallVector> reifiedResultShapesPerDim; - if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim( - builder, reifiedResultShapesPerDim))) - return nullptr; - if (reifiedResultShapesPerDim.size() <= resultNumber || - reifiedResultShapesPerDim[resultNumber].size() != - static_cast(result.getType().cast().getRank())) - return nullptr; - OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex]; - if (auto attr = valueOrAttr.dyn_cast()) - return builder.createOrFold( - loc, attr.cast().getInt()); - return valueOrAttr.get(); -} - -/// Fold dim of an operation that implements the InferShapedTypeOpInterface -struct DimOfShapedTypeOpInterface : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(DimOp dimOp, - PatternRewriter &rewriter) const override { - OpResult 
dimValue = dimOp.memrefOrTensor().dyn_cast(); - if (!dimValue) - return failure(); - auto shapedTypeOp = - dyn_cast(dimValue.getOwner()); - if (!shapedTypeOp) - return failure(); - - Optional dimIndex = dimOp.getConstantIndex(); - if (!dimIndex) - return failure(); - Value replacement = - getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex); - if (!replacement) - return failure(); - rewriter.replaceOp(dimOp, replacement); - return success(); - } -}; } // end anonymous namespace. void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add, - DimOfCastOp, DimOfShapedTypeOpInterface>(context); + DimOfCastOp>(context); } // --------------------------------------------------------------------------- diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRMemRefTransforms FoldSubViewOps.cpp + ResolveShapedTypeResultDims.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef @@ -9,9 +10,11 @@ LINK_LIBS PUBLIC MLIRAffine + MLIRInferTypeOpInterface MLIRMemRef MLIRPass MLIRStandard + MLIRTensor MLIRTransforms MLIRVector ) diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp @@ -0,0 +1,127 @@ +//===- ResolveShapedTypeResultDims.cpp - Resolve memref.dim ops of result values +//-------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass resolves `memref.dim` operations of result values in terms of +// shapes of their operands using the `InferShapedTypeOpInterface`. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +/// Returns the size of dimension `dimIndex` of `result`, reified through the +/// owner's `InferShapedTypeOpInterface`; returns null if it cannot be reified. +/// TODO(ravishankarm): This is better put as an interface utility method +/// somewhere, but that would imply the interface will depend on the `tensor` +/// dialect. Ideally maybe a utility method in the `tensor` dialect. +static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result, + int64_t dimIndex) { + unsigned resultNumber = result.getResultNumber(); + auto shapedTypeOp = dyn_cast(result.getOwner()); + Location loc = result.getOwner()->getLoc(); + if (!shapedTypeOp) + return nullptr; + + // The interface exposes two methods, one that returns the shape of all the + // results as `Value` and another that returns the shape as a list of + // `SmallVector`. The former takes precedence over the latter. So first + // check if the op implements the first interface method or the second, and + // get the value to use appropriately. 
+ SmallVector reifiedResultShapes; + if (succeeded(shapedTypeOp.reifyReturnTypeShapes( + builder, result.getOwner()->getOperands(), reifiedResultShapes))) { + if (reifiedResultShapes.size() <= resultNumber) + return nullptr; + Value resultShape = reifiedResultShapes[resultNumber]; + auto resultShapeType = resultShape.getType().dyn_cast(); + if (!resultShapeType || !resultShapeType.getElementType().isa()) + return nullptr; + return builder.create( + loc, resultShape, builder.createOrFold(loc, dimIndex)); + } + + SmallVector> reifiedResultShapesPerDim; + if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim( + builder, reifiedResultShapesPerDim))) + return nullptr; + if (reifiedResultShapesPerDim.size() <= resultNumber || + reifiedResultShapesPerDim[resultNumber].size() != + static_cast(result.getType().cast().getRank())) + return nullptr; + OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex]; + if (auto attr = valueOrAttr.dyn_cast()) + return builder.createOrFold( + loc, attr.cast().getInt()); + return valueOrAttr.get(); +} + +namespace { +/// Fold dim of an operation that implements the InferShapedTypeOpInterface. +struct DimOfShapedTypeOpInterface : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::DimOp dimOp, + PatternRewriter &rewriter) const override { + OpResult dimValue = dimOp.memrefOrTensor().dyn_cast(); + if (!dimValue) + return failure(); + auto shapedTypeOp = + dyn_cast(dimValue.getOwner()); + if (!shapedTypeOp) + return failure(); + + Optional dimIndex = dimOp.getConstantIndex(); + if (!dimIndex) + return failure(); + Value replacement = + getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex); + if (!replacement) + return failure(); + rewriter.replaceOp(dimOp, replacement); + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Pass registration 
+//===----------------------------------------------------------------------===// + +namespace { +#define GEN_PASS_CLASSES +#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" + +struct ResolveShapedTypeResultDimsPass final + : public ResolveShapedTypeResultDimsBase { + void runOnOperation() override; +}; +} // namespace + +void memref::populateResolveShapedTypeResultDimsPatterns( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +void ResolveShapedTypeResultDimsPass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + memref::populateResolveShapedTypeResultDimsPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(getOperation()->getRegions(), + std::move(patterns)))) + return signalPassFailure(); +} + +std::unique_ptr memref::createResolveShapedTypeResultDimsPass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -532,205 +532,6 @@ // ----- -func @init_tensor_static_dim() -> (index, index) { - %c0 = constant 0 : index - %c2 = constant 2 : index - %c6 = constant 6 : index - %0 = linalg.init_tensor [4, 5, %c6] : tensor<4x5x?xf32> - %1 = memref.dim %0, %c2 : tensor<4x5x?xf32> - %2 = memref.dim %0, %c0 : tensor<4x5x?xf32> - return %1, %2 : index, index -} -// CHECK: func @init_tensor_static_dim -// CHECK-DAG: %[[C4:.+]] = constant 4 : index -// CHECK-DAG: %[[C6:.+]] = constant 6 : index -// CHECK: return %[[C6]], %[[C4]] - -// ----- - -func @init_tensor_dynamic_dim(%arg0 : index) -> (index) { - %c2 = constant 2 : index - %0 = linalg.init_tensor [4, 5, %arg0] : tensor<4x5x?xf32> - %1 = memref.dim %0, %c2 : tensor<4x5x?xf32> - return %1 : index -} -// CHECK: func @init_tensor_dynamic_dim -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index -// CHECK: return %[[ARG0]] - -// ----- - -func @init_tensor_dynamic_dim2(%arg0 : index, %arg1 : index) 
-> (index, index) { - %c0 = constant 0 : index - %c1 = constant 1 : index - %0 = linalg.init_tensor [%arg0, %arg1] : tensor - %1 = memref.dim %0, %c0 : tensor - %2 = memref.dim %0, %c1 : tensor - return %1, %2 : index, index -} -// CHECK: func @init_tensor_dynamic_dim2 -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index -// CHECK: return %[[ARG0]], %[[ARG1]] - -// ----- - -func @remove_dim_result_uses - (%arg0 : tensor, %arg1 : tensor, - %arg2 : tensor) -> (index, index) { - %c0 = constant 0 : index - %c1 = constant 1 : index - %0 = linalg.generic - {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, - affine_map<(d0, d1, d2) -> (d2, d1)>, - affine_map<(d0, d1, d2) -> (d0 + d1, d1 - d0)>], - iterator_types = ["parallel", "parallel", "reduction"]} - ins(%arg0, %arg1 : tensor, tensor) - outs(%arg2 : tensor) { - ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32): - %1 = mulf %arg3, %arg4 : f32 - %2 = addf %1, %arg5 : f32 - linalg.yield %2 : f32 - } -> tensor - %3 = memref.dim %0, %c0 : tensor - %4 = memref.dim %0, %c1 : tensor - return %3, %4 : index, index -} -// CHECK: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)> -// CHECK: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (-s0 + s1)> -// CHECK: func @remove_dim_result_uses -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor -// CHECK-DAG: %[[C0:.+]] = constant 0 : index -// CHECK-DAG: %[[C1:.+]] = constant 1 : index -// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]] -// CHECK-DAG: %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]] -// CHECK: %[[T2:.+]] = affine.apply #[[MAP0]]()[%[[T0]], %[[T1]]] -// CHECK-DAG: %[[T3:.+]] = memref.dim %[[ARG0]], %[[C0]] -// CHECK-DAG: %[[T4:.+]] = memref.dim %[[ARG1]], %[[C1]] -// CHECK: %[[T5:.+]] = affine.apply #[[MAP1]]()[%[[T3]], %[[T4]]] -// CHECK: return %[[T2]], %[[T5]] - -// ----- - -func @remove_dim_result_uses_outs - (%arg0 : tensor, %arg1 : index) 
-> (index) { - %c0 = constant 0 : index - %c1 = constant 1 : index - %d0 = memref.dim %arg0, %c0 : tensor - %0 = linalg.init_tensor [%d0, %arg1] : tensor - %1 = linalg.generic - {indexing_maps = [affine_map<(d0, d1) -> (d0)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%arg0 : tensor) outs(%0 : tensor) { - ^bb0(%arg2: f32, %arg3: f32) : - linalg.yield %arg2 : f32 - } -> tensor - %2 = memref.dim %1, %c1 : tensor - return %2 : index -} -// CHECK: func @remove_dim_result_uses_outs -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index -// CHECK: return %[[ARG1]] - -// ----- - -func @remove_dim_result_uses_sequence - (%arg0 : tensor, %arg1 : tensor, - %arg2 : tensor) -> (index, index, index, index) { - %c0 = constant 0 : index - %c1 = constant 1 : index - %0 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) - outs(%arg2 : tensor) -> tensor - %1 = memref.dim %0, %c0 : tensor - %2 = memref.dim %0, %c1 : tensor - %3 = linalg.generic - {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0)>, - affine_map<(d0, d1, d2) -> (d0, d2)>, - affine_map<(d0, d1, d2) -> (d0, d2)>], - iterator_types = ["parallel", "reduction", "parallel"]} - ins(%arg0, %arg1 : tensor, tensor) - outs(%0 : tensor) { - ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32): - %4 = mulf %arg3, %arg4 : f32 - %5 = addf %4, %arg5 : f32 - linalg.yield %5 : f32 - } -> tensor - %6 = memref.dim %3, %c0 : tensor - %7 = memref.dim %3, %c1 : tensor - return %1, %2, %6, %7 : index, index, index, index -} -// CHECK-LABEL: func @remove_dim_result_uses_sequence -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor -// CHECK-DAG: %[[C0:.+]] = constant 0 : index -// CHECK-DAG: %[[C1:.+]] = constant 1 : index -// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]] -// CHECK-DAG: %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]] -// CHECK-DAG: %[[T2:.+]] = memref.dim %[[ARG0]], %[[C1]] -// CHECK-DAG: %[[T3:.+]] = 
memref.dim %[[ARG1]], %[[C1]] -// CHECK: return %[[T0]], %[[T1]], %[[T2]], %[[T3]] - -// ----- - -func @keep_result_dim_uses_sequence2 - (%arg0 : tensor, %arg1 : index) -> (index, index) { - %c0 = constant 0 : index - %c1 = constant 1 : index - %d0 = memref.dim %arg0, %c0 : tensor - %0 = linalg.init_tensor [%d0, %arg1] : tensor - %1 = linalg.generic - {indexing_maps = [affine_map<(d0, d1) -> (d0)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%arg0 : tensor) outs(%0 : tensor) { - ^bb0(%arg2: f32, %arg3 : f32): - linalg.yield %arg2 : f32 - } -> tensor - %2 = memref.dim %1, %c0 : tensor - %3 = memref.dim %1, %c1 : tensor - return %2, %3 : index, index -} -// CHECK: func @keep_result_dim_uses_sequence2 -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index -// CHECK-DAG: %[[C0:.+]] = constant 0 : index -// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]] -// CHECK: return %[[T0]], %[[ARG1]] - -// ----- - -#map = affine_map<(d0) -> (d0)> - -func @init_tensor_dim_of_linalg_result(%arg_0 : tensor, - %arg_1: tensor) -> (index, index) { - %0, %1 = linalg.generic { - indexing_maps = [#map, #map, #map], - iterator_types = ["parallel"] - } ins(%arg_0 : tensor) - outs(%arg_0, %arg_1 : tensor, tensor) { - ^bb0(%in: f32, %out_0: f32, %out_1: f32): - linalg.yield %in, %in : f32, f32 - } -> (tensor, tensor) - - %c0 = constant 0 : index - %num_elem_0 = memref.dim %0, %c0 : tensor - - %num_elem_1 = memref.dim %1, %c0 : tensor - return %num_elem_0, %num_elem_1 : index, index -} -// CHECK: func @init_tensor_dim_of_linalg_result( -// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: tensor) -// CHECK: %[[R0:.+]] = memref.dim %[[ARG_0]] -// CHECK: %[[R1:.+]] = memref.dim %[[ARG_0]] -// CHECK: return %[[R0]], %[[R1]] - -// ----- - func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> { %0 = linalg.init_tensor [6, 5, %arg0] : 
tensor<6x5x?xf32> %1 = linalg.tensor_expand_shape %0 [[0, 1], [2], [3, 4, 5]] @@ -740,9 +541,12 @@ // CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)> // CHECK: func @init_tensor_reshape_expansion // CHECK-SAME: %[[ARG0:.+]]: index -// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]] -// CHECK: %[[T1:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[T0]], 7] -// CHECK: return %[[T1]] +// CHECK: %[[C2:.+]] = constant 2 +// CHECK: %[[INIT1:.+]] = linalg.init_tensor [6, 5, %[[ARG0]]] +// CHECK: %[[D0:.+]] = memref.dim %[[INIT1]], %[[C2]] +// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[D0]]] +// CHECK: %[[INIT2:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[T0]], 7] +// CHECK: return %[[INIT2]] // ----- @@ -755,9 +559,12 @@ // CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)> // CHECK: func @init_tensor_reshape_collapse // CHECK-SAME: %[[ARG0:.+]]: index -// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]] -// CHECK: %[[T1:.+]] = linalg.init_tensor [6, 5, %[[T0]]] -// CHECK: return %[[T1]] +// CHECK: %[[C4:.+]] = constant 4 +// CHECK: %[[INIT1:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[ARG0]], 7] +// CHECK: %[[D0:.+]] = memref.dim %[[INIT1]], %[[C4]] +// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[D0]]] +// CHECK: %[[INIT2:.+]] = linalg.init_tensor [6, 5, %[[T0]]] +// CHECK: return %[[INIT2]] // ----- @@ -906,54 +713,6 @@ } : tensor<5x6xf32> to tensor<5x6xf32> return %0 : tensor<5x6xf32> } - -// ----- - -func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index) -{ - %c1 = constant 1 : index - %c3 = constant 3 : index - %c4 = constant 4 : index - %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2], [3, 4, 5]] - : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32> - %1 = memref.dim %0, %c1 : tensor<2x3x5x4x?x7xf32> - %2 = memref.dim %0, %c3 : tensor<2x3x5x4x?x7xf32> - %3 = memref.dim %0, %c4 : tensor<2x3x5x4x?x7xf32> - return %1, %2, %3 : index, index, index -} -// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)> -// CHECK: 
func @dim_reshape_expansion -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32> -// CHECK-DAG: %[[C2:.+]] = constant 2 : index -// CHECK-DAG: %[[C3:.+]] = constant 3 : index -// CHECK-DAG: %[[C4:.+]] = constant 4 : index -// CHECK: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C2]] -// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]] -// CHECK: return %[[C3]], %[[C4]], %[[D1]] - -// ----- - -func @dim_reshape_collapse(%arg0 : tensor<2x3x5x4x?x7xf32>) -> (index, index) -{ - %c1 = constant 1 : index - %c2 = constant 2 : index - %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2], [3, 4, 5]] - : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32> - %1 = memref.dim %0, %c1 : tensor<6x5x?xf32> - %2 = memref.dim %0, %c2 : tensor<6x5x?xf32> - return %1, %2 : index, index -} -// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)> -// CHECK: func @dim_reshape_collapse -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x3x5x4x?x7xf32> -// CHECK-DAG: %[[C4:.+]] = constant 4 : index -// CHECK-DAG: %[[C5:.+]] = constant 5 : index -// CHECK: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C4]] -// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]] -// CHECK: return %[[C5]], %[[D1]] - -// ----- - func @propogate_casts(%arg0 : tensor, %arg1 : f32, %arg2 : index, %arg3 : index) -> tensor { %c0 = constant 0 : index @@ -1083,41 +842,6 @@ // ----- -func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index, - %arg3: f32) -> (index, index, index) -{ - %c0 = constant 0 : index - %c1 = constant 1 : index - %c2 = constant 2 : index - %c3 = constant 3 : index - %c4 = constant 4 : index - %c5 = constant 5 : index - %0 = linalg.pad_tensor %arg0 low[%c3, %arg1, %c4] high[7, %c5, %arg2] { - ^bb0(%arg4: index, %arg5: index, %arg6: index): - linalg.yield %arg3 : f32 - } : tensor<2x?x?xf32> to tensor - %1 = memref.dim %0, %c0 : tensor - %2 = memref.dim %0, %c1 : tensor - %3 = memref.dim %0, %c2 : tensor - return %1, %2, %3 : index, index, index -} -// CHECK-DAG: #[[MAP0:.+]] = 
affine_map<()[s0, s1] -> (s0 + s1 + 5)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 4)> -// CHECK: func @dim_of_pad_op -// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<2x?x?xf32> -// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]+]]: index -// CHECK-SAME: %[[ARG2:[A-Za-z0-9_]+]]: index -// CHECK-DAG: %[[C1:.+]] = constant 1 : index -// CHECK-DAG: %[[C2:.+]] = constant 2 : index -// CHECK-DAG: %[[C12:.+]] = constant 12 : index -// CHECK: %[[IN_DIM1:.+]] = memref.dim %[[ARG0]], %[[C1]] -// CHECK: %[[OUT_DIM1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[IN_DIM1]]] -// CHECK: %[[IN_DIM2:.+]] = memref.dim %[[ARG0]], %[[C2]] -// CHECK: %[[OUT_DIM2:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[IN_DIM2]]] -// CHECK: return %[[C12]], %[[OUT_DIM1]], %[[OUT_DIM2]] - -// ----- - #map = affine_map<(d0, d1) -> (d0, d1)> func @indexed_generic(%arg0: memref, %arg1: memref) { diff --git a/mlir/test/Dialect/Linalg/fusion-sequence.mlir b/mlir/test/Dialect/Linalg/fusion-sequence.mlir --- a/mlir/test/Dialect/Linalg/fusion-sequence.mlir +++ b/mlir/test/Dialect/Linalg/fusion-sequence.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -pass-pipeline="func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),canonicalize,cse" -split-input-file %s | FileCheck %s +// RUN: mlir-opt -pass-pipeline="func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),resolve-shaped-type-result-dims,canonicalize,cse" -split-input-file %s | FileCheck %s module { func @three_op_fusion(%arg0: memref, %arg1: memref, diff --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir --- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir +++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -canonicalize -cse --split-input-file | FileCheck %s -// RUN: mlir-opt %s -test-linalg-tiled-loop-fusion-transform-patterns -canonicalize -cse --split-input-file | FileCheck %s --check-prefix=TLOOP 
+// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -resolve-shaped-type-result-dims -canonicalize -cse --split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-linalg-tiled-loop-fusion-transform-patterns -resolve-shaped-type-result-dims -canonicalize -cse --split-input-file | FileCheck %s --check-prefix=TLOOP module { func @matmul_fusion(%A: tensor, %B: tensor, diff --git a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir @@ -0,0 +1,278 @@ +// RUN: mlir-opt -resolve-shaped-type-result-dims -split-input-file %s | FileCheck %s + +func @init_tensor_static_dim() -> (index, index) { + %c0 = constant 0 : index + %c2 = constant 2 : index + %c6 = constant 6 : index + %0 = linalg.init_tensor [4, 5, %c6] : tensor<4x5x?xf32> + %1 = memref.dim %0, %c2 : tensor<4x5x?xf32> + %2 = memref.dim %0, %c0 : tensor<4x5x?xf32> + return %1, %2 : index, index +} +// CHECK: func @init_tensor_static_dim +// CHECK-DAG: %[[C4:.+]] = constant 4 : index +// CHECK-DAG: %[[C6:.+]] = constant 6 : index +// CHECK: return %[[C6]], %[[C4]] + +// ----- + +func @init_tensor_dynamic_dim(%arg0 : index) -> (index) { + %c2 = constant 2 : index + %0 = linalg.init_tensor [4, 5, %arg0] : tensor<4x5x?xf32> + %1 = memref.dim %0, %c2 : tensor<4x5x?xf32> + return %1 : index +} +// CHECK: func @init_tensor_dynamic_dim +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index +// CHECK: return %[[ARG0]] + +// ----- + +func @init_tensor_dynamic_dim2(%arg0 : index, %arg1 : index) -> (index, index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %0 = linalg.init_tensor [%arg0, %arg1] : tensor + %1 = memref.dim %0, %c0 : tensor + %2 = memref.dim %0, %c1 : tensor + return %1, %2 : index, index +} +// CHECK: func @init_tensor_dynamic_dim2 +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: 
index +// CHECK: return %[[ARG0]], %[[ARG1]] + +// ----- + +func @remove_dim_result_uses + (%arg0 : tensor, %arg1 : tensor, + %arg2 : tensor) -> (index, index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %0 = linalg.generic + {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0 + d1, d1 - d0)>], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%arg0, %arg1 : tensor, tensor) + outs(%arg2 : tensor) { + ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32): + %1 = mulf %arg3, %arg4 : f32 + %2 = addf %1, %arg5 : f32 + linalg.yield %2 : f32 + } -> tensor + %3 = memref.dim %0, %c0 : tensor + %4 = memref.dim %0, %c1 : tensor + return %3, %4 : index, index +} +// CHECK: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)> +// CHECK: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s1 - s0)> +// CHECK: func @remove_dim_result_uses +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]] +// CHECK: %[[T2:.+]] = affine.apply #[[MAP0]]()[%[[T0]], %[[T1]]] +// CHECK-DAG: %[[T3:.+]] = memref.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[T4:.+]] = memref.dim %[[ARG1]], %[[C1]] +// CHECK: %[[T5:.+]] = affine.apply #[[MAP1]]()[%[[T3]], %[[T4]]] +// CHECK: return %[[T2]], %[[T5]] + +// ----- + +func @remove_dim_result_uses_outs + (%arg0 : tensor, %arg1 : index) -> (index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %d0 = memref.dim %arg0, %c0 : tensor + %0 = linalg.init_tensor [%d0, %arg1] : tensor + %1 = linalg.generic + {indexing_maps = [affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%arg0 : tensor) outs(%0 : tensor) { + 
^bb0(%arg2: f32, %arg3: f32) : + linalg.yield %arg2 : f32 + } -> tensor + %2 = memref.dim %1, %c1 : tensor + return %2 : index +} +// CHECK: func @remove_dim_result_uses_outs +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK: return %[[ARG1]] + +// ----- + +func @remove_dim_result_uses_sequence + (%arg0 : tensor, %arg1 : tensor, + %arg2 : tensor) -> (index, index, index, index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %0 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) + outs(%arg2 : tensor) -> tensor + %1 = memref.dim %0, %c0 : tensor + %2 = memref.dim %0, %c1 : tensor + %3 = linalg.generic + {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0)>, + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d0, d2)>], + iterator_types = ["parallel", "reduction", "parallel"]} + ins(%arg0, %arg1 : tensor, tensor) + outs(%0 : tensor) { + ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32): + %4 = mulf %arg3, %arg4 : f32 + %5 = addf %4, %arg5 : f32 + linalg.yield %5 : f32 + } -> tensor + %6 = memref.dim %3, %c0 : tensor + %7 = memref.dim %3, %c1 : tensor + return %1, %2, %6, %7 : index, index, index, index +} +// CHECK-LABEL: func @remove_dim_result_uses_sequence +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]] +// CHECK-DAG: %[[T2:.+]] = memref.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[T3:.+]] = memref.dim %[[ARG1]], %[[C1]] +// CHECK: return %[[T0]], %[[T1]], %[[T2]], %[[T3]] + +// ----- + +func @keep_result_dim_uses_sequence2 + (%arg0 : tensor, %arg1 : index) -> (index, index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %d0 = memref.dim %arg0, %c0 : tensor + %0 = linalg.init_tensor [%d0, %arg1] : tensor + %1 = linalg.generic + 
{indexing_maps = [affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%arg0 : tensor) outs(%0 : tensor) { + ^bb0(%arg2: f32, %arg3 : f32): + linalg.yield %arg2 : f32 + } -> tensor + %2 = memref.dim %1, %c0 : tensor + %3 = memref.dim %1, %c1 : tensor + return %2, %3 : index, index +} +// CHECK: func @keep_result_dim_uses_sequence2 +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]] +// CHECK: return %[[T0]], %[[ARG1]] + +// ----- + +#map = affine_map<(d0) -> (d0)> + +func @init_tensor_dim_of_linalg_result(%arg_0 : tensor, + %arg_1: tensor) -> (index, index) { + %0, %1 = linalg.generic { + indexing_maps = [#map, #map, #map], + iterator_types = ["parallel"] + } ins(%arg_0 : tensor) + outs(%arg_0, %arg_1 : tensor, tensor) { + ^bb0(%in: f32, %out_0: f32, %out_1: f32): + linalg.yield %in, %in : f32, f32 + } -> (tensor, tensor) + + %c0 = constant 0 : index + %num_elem_0 = memref.dim %0, %c0 : tensor + + %num_elem_1 = memref.dim %1, %c0 : tensor + return %num_elem_0, %num_elem_1 : index, index +} +// CHECK: func @init_tensor_dim_of_linalg_result( +// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: tensor) +// CHECK: %[[R0:.+]] = memref.dim %[[ARG_0]] +// CHECK: %[[R1:.+]] = memref.dim %[[ARG_0]] +// CHECK: return %[[R0]], %[[R1]] + +// ----- + +func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index) +{ + %c1 = constant 1 : index + %c3 = constant 3 : index + %c4 = constant 4 : index + %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2], [3, 4, 5]] + : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32> + %1 = memref.dim %0, %c1 : tensor<2x3x5x4x?x7xf32> + %2 = memref.dim %0, %c3 : tensor<2x3x5x4x?x7xf32> + %3 = memref.dim %0, %c4 : tensor<2x3x5x4x?x7xf32> + return %1, %2, %3 : index, index, index +} +// 
CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)> +// CHECK: func @dim_reshape_expansion +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32> +// CHECK-DAG: %[[C2:.+]] = constant 2 : index +// CHECK-DAG: %[[C3:.+]] = constant 3 : index +// CHECK-DAG: %[[C4:.+]] = constant 4 : index +// CHECK: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C2]] +// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]] +// CHECK: return %[[C3]], %[[C4]], %[[D1]] + +// ----- + +func @dim_reshape_collapse(%arg0 : tensor<2x3x5x4x?x7xf32>) -> (index, index) +{ + %c1 = constant 1 : index + %c2 = constant 2 : index + %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2], [3, 4, 5]] + : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32> + %1 = memref.dim %0, %c1 : tensor<6x5x?xf32> + %2 = memref.dim %0, %c2 : tensor<6x5x?xf32> + return %1, %2 : index, index +} +// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)> +// CHECK: func @dim_reshape_collapse +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x3x5x4x?x7xf32> +// CHECK-DAG: %[[C4:.+]] = constant 4 : index +// CHECK-DAG: %[[C5:.+]] = constant 5 : index +// CHECK: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C4]] +// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]] +// CHECK: return %[[C5]], %[[D1]] + +// ----- + +func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index, + %arg3: f32) -> (index, index, index) +{ + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %c3 = constant 3 : index + %c4 = constant 4 : index + %c5 = constant 5 : index + %0 = linalg.pad_tensor %arg0 low[%c3, %arg1, %c4] high[7, %c5, %arg2] { + ^bb0(%arg4: index, %arg5: index, %arg6: index): + linalg.yield %arg3 : f32 + } : tensor<2x?x?xf32> to tensor + %1 = memref.dim %0, %c0 : tensor + %2 = memref.dim %0, %c1 : tensor + %3 = memref.dim %0, %c2 : tensor + return %1, %2, %3 : index, index, index +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s1 + s0 + 5)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, 
s1] -> (s1 + s0 + 4)> +// CHECK: func @dim_of_pad_op +// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<2x?x?xf32> +// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]+]]: index +// CHECK-SAME: %[[ARG2:[A-Za-z0-9_]+]]: index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[C2:.+]] = constant 2 : index +// CHECK-DAG: %[[C12:.+]] = constant 12 : index +// CHECK: %[[IN_DIM1:.+]] = memref.dim %[[ARG0]], %[[C1]] +// CHECK: %[[OUT_DIM1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[IN_DIM1]]] +// CHECK: %[[IN_DIM2:.+]] = memref.dim %[[ARG0]], %[[C2]] +// CHECK: %[[OUT_DIM2:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[IN_DIM2]]] +// CHECK: return %[[C12]], %[[OUT_DIM1]], %[[OUT_DIM2]] diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir --- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir @@ -205,16 +205,14 @@ // CHECK: #[[BOUND8_MAP:.+]] = affine_map<(d0)[s0] -> (8, -d0 + s0)> // CHECK: #[[BOUND8_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 8, -d0 + s1)> -// CHECK: #[[BOUND8_MAP_3:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 8)> // CHECK: #[[BOUND16_MAP:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)> // CHECK: #[[X2_MAP:.+]] = affine_map<(d0) -> (d0 * 2)> // CHECK: #[[INPUT_BOUND:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * 2 + s0 - 2, d1 * -2 + s1)> -// CHECK: #[[BOUND16_MAP_2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)> +// CHECK: #[[BOUND16_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 16, -d0 + s1)> // CHECK: #[[BOUND4_MAP:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)> // CHECK: #[[BOUND2_MAP:.+]] = affine_map<(d0)[s0] -> (2, -d0 + s0)> -// CHECK: #[[BOUND4_MAP_2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)> +// CHECK: #[[BOUND4_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 4, -d0 + s1)> // CHECK: #[[BOUND2_MAP_2:.+]] = affine_map<(d0, d1)[s0, s1] -> (-d0 + s0, 2, -d1 + s1)> -// CHECK: #[[BOUND2_MAP_3:.+]] = affine_map<(d0, d1)[s0] -> 
(-d0 + s0, 2, -d1 + s0)> // CHECK: func @conv_tensors_dynamic // CHECK-SAME: (%[[INPUT]]: tensor, %[[FILTER]]: tensor, %[[ELEM]]: tensor) @@ -240,16 +238,20 @@ // CHECK-DAG: %[[INPUT_C:.+]] = memref.dim %[[INPUT]], %[[C3]] : tensor // CHECK-DAG: %[[FILTER_IC:.+]] = memref.dim %[[FILTER]], %[[C2]] : tensor // CHECK-DAG: %[[FILTER_OC:.+]] = memref.dim %[[FILTER]], %[[C3]] : tensor +// CHECK-DAG: %[[FILL_N:.+]] = memref.dim %[[FILL]], %[[C0]] : tensor +// CHECK-DAG: %[[FILL_H:.+]] = memref.dim %[[FILL]], %[[C1]] : tensor +// CHECK-DAG: %[[FILL_W:.+]] = memref.dim %[[FILL]], %[[C2]] : tensor +// CHECK-DAG: %[[FILL_C:.+]] = memref.dim %[[FILL]], %[[C3]] : tensor // CHECK: scf.for %[[IV0:.+]] = %{{.+}} to %[[ELEM_N]] step %{{.+}} iter_args(%{{.+}} = %[[FILL]]) // CHECK-NEXT: %[[SIZE_ELEM_N:.+]] = affine.min #[[BOUND8_MAP]](%[[IV0]])[%[[ELEM_N]]] // CHECK-NEXT: %[[SIZE_INPUT_N:.+]] = affine.min #[[BOUND8_MAP_2]](%[[IV0]])[%[[INPUT_N]], %[[ELEM_N]]] -// CHECK-NEXT: %[[SIZE_ELEM_N_2:.+]] = affine.min #[[BOUND8_MAP_3]](%[[IV0]])[%[[ELEM_N]]] +// CHECK-NEXT: %[[SIZE_ELEM_N_2:.+]] = affine.min #[[BOUND8_MAP_2]](%[[IV0]])[%[[FILL_N]], %[[ELEM_N]]] // CHECK-NEXT: scf.for %[[IV1:.+]] = %{{.+}} to %[[ELEM_OH]] // CHECK-NEXT: %[[SIZE_ELEM_OH:.+]] = affine.min #[[BOUND16_MAP]](%[[IV1]])[%[[ELEM_OH]]] // CHECK-NEXT: %[[OFFSET_OH:.+]] = affine.apply #[[X2_MAP]](%[[IV1]]) // CHECK-NEXT: %[[SIZE_INPUT_H:.+]] = affine.min #[[INPUT_BOUND]](%[[SIZE_ELEM_OH]], %[[IV1]])[%[[FILTER_H]], %[[INPUT_H]]] -// CHECK-NEXT: %[[SIZE_ELEM_OH_2:.+]] = affine.min #[[BOUND16_MAP_2]](%[[IV1]])[%[[ELEM_OH]]] +// CHECK-NEXT: %[[SIZE_ELEM_OH_2:.+]] = affine.min #[[BOUND16_MAP_2]](%[[IV1]])[%[[FILL_H]], %[[ELEM_OH]]] // CHECK-NEXT: scf.for %[[IV2:.+]] = %{{.+}} to %[[ELEM_OW]] // CHECK-NEXT: %[[SIZE_ELEM_OW:.+]] = affine.min #[[BOUND4_MAP]](%[[IV2]])[%[[ELEM_OW]]] // CHECK-NEXT: %[[SIZE_ELEM_OC:.+]] = affine.min #[[BOUND2_MAP]](%[[IV2]])[%[[ELEM_OC]]] @@ -257,7 +259,7 @@ // CHECK-NEXT: %[[SIZE_INPUT_W:.+]] = 
affine.min #[[INPUT_BOUND]](%[[SIZE_ELEM_OW]], %[[IV2]])[%[[FILTER_W]], %[[INPUT_W]]] // CHECK-NEXT: %[[ST_INPUT:.+]] = subtensor %[[INPUT]][%[[IV0]], %[[OFFSET_OH]], %[[OFFSET_OW]], 0] // CHECK-SAME: [%[[SIZE_INPUT_N]], %[[SIZE_INPUT_H]], %[[SIZE_INPUT_W]], %[[INPUT_C]]] -// CHECK-NEXT: %[[SIZE_ELEM_OW_2:.+]] = affine.min #[[BOUND4_MAP_2]](%[[IV2]])[%[[ELEM_OW]]] +// CHECK-NEXT: %[[SIZE_ELEM_OW_2:.+]] = affine.min #[[BOUND4_MAP_2]](%[[IV2]])[%[[FILL_W]], %[[ELEM_OW]]] // CHECK-NEXT: scf.for %[[IV3:.+]] = %{{.+}} to %[[ELEM_OC]] step %{{.+}} iter_args(%[[ARG:[a-z0-9]+]] // CHECK-NEXT: %[[ST_ELEM:.+]] = subtensor %[[ELEM]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]] // CHECK-SAME: [%[[SIZE_ELEM_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC]]] @@ -266,7 +268,7 @@ // CHECK-NEXT: %[[SIZE_ELEM_OC_2:.+]] = affine.min #[[BOUND2_MAP_2]](%[[IV3]], %[[IV2]])[%[[FILTER_OC]], %[[ELEM_OC]]] // CHECK-NEXT: %[[ST_FILTER:.+]] = subtensor %[[FILTER]][0, 0, 0, %[[IV3]]] // CHECK-SAME: [%[[FILTER_H]], %[[FILTER_W]], %[[FILTER_IC]], %[[SIZE_ELEM_OC_2]]] -// CHECK-NEXT: %[[SIZE_ELEM_OC_3:.+]] = affine.min #[[BOUND2_MAP_3]](%[[IV3]], %[[IV2]])[%[[ELEM_OC]]] +// CHECK-NEXT: %[[SIZE_ELEM_OC_3:.+]] = affine.min #[[BOUND2_MAP_2]](%[[IV3]], %[[IV2]])[%[[FILL_C]], %[[ELEM_OC]]] // CHECK-NEXT: %[[ST_FILL:.+]] = subtensor %[[FILL]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]] // CHECK-SAME: [%[[SIZE_ELEM_N_2]], %[[SIZE_ELEM_OH_2]], %[[SIZE_ELEM_OW_2]], %[[SIZE_ELEM_OC_3]]] // CHECK-NEXT: %[[ST_CONV:.+]] = linalg.conv_2d_input_nhwc_filter_hwcf diff --git a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir @@ -0,0 +1,88 @@ +// RUN: mlir-opt %s -resolve-shaped-type-result-dims -split-input-file | FileCheck %s + +func @result_shape(%arg0 : 
tensor<2x3x?xf32>, %arg1 : tensor) + -> (index, index, index, index, index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %0:2 = "test.op_with_result_shape_interface"(%arg0, %arg1) + : (tensor<2x3x?xf32>, tensor) -> (tensor, tensor<2x3x?xf32>) + %1 = memref.dim %0#0, %c0 : tensor + %2 = memref.dim %0#0, %c1 : tensor + %3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32> + %4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32> + %5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32> + return %1, %2, %3, %4, %5 : index, index, index, index, index +} +// CHECK-LABEL: func @result_shape( +// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32> +// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor) +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C2:.+]] = constant 2 : index +// CHECK-DAG: %[[C3:.+]] = constant 3 : index +// CHECK-DAG: %[[C5:.+]] = constant 5 : index +// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]] +// CHECK-DAG: %[[S0:.+]] = tensor.from_elements %[[D0]], %[[C5]] +// CHECK-DAG: %[[D0_OUT:.+]] = tensor.extract %[[S0]][%[[C0]]] +// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]] +// CHECK-DAG: %[[S1:.+]] = tensor.from_elements %[[C2]], %[[C3]], %[[D1]] +// CHECK-DAG: %[[D1_OUT:.+]] = tensor.extract %[[S1]][%[[C2]]] +// CHECK: return %[[D0_OUT]], %[[C5]], %[[C2]], %[[C3]], %[[D1_OUT]] + +// ----- + +func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor) + -> (index, index, index, index, index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %0:2 = "test.op_with_result_shape_per_dim_interface"(%arg0, %arg1) + : (tensor<2x3x?xf32>, tensor) -> (tensor, tensor<2x3x?xf32>) + %1 = memref.dim %0#0, %c0 : tensor + %2 = memref.dim %0#0, %c1 : tensor + %3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32> + %4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32> + %5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32> + return %1, %2, %3, %4, %5 : index, index, index, index, index +} +//
CHECK-LABEL: func @result_shape_per_dim +// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32> +// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor) +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C2:.+]] = constant 2 : index +// CHECK-DAG: %[[C3:.+]] = constant 3 : index +// CHECK-DAG: %[[C5:.+]] = constant 5 : index +// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]] +// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]] + +// ----- + +func @result_shape_and_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor) + -> (index, index, index, index, index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %0:2 = "test.op_with_result_shape_and_per_dim_interface"(%arg0, %arg1) + : (tensor<2x3x?xf32>, tensor) -> (tensor, tensor<2x3x?xf32>) + %1 = memref.dim %0#0, %c0 : tensor + %2 = memref.dim %0#0, %c1 : tensor + %3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32> + %4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32> + %5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32> + return %1, %2, %3, %4, %5 : index, index, index, index, index +} +// CHECK-LABEL: func @result_shape_and_per_dim +// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32> +// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor) +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C2:.+]] = constant 2 : index +// CHECK-DAG: %[[C3:.+]] = constant 3 : index +// CHECK-DAG: %[[C5:.+]] = constant 5 : index +// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]] +// CHECK-DAG: %[[S0:.+]] = tensor.from_elements %[[D0]], %[[C5]] +// CHECK-DAG: %[[D0_OUT:.+]] = tensor.extract %[[S0]][%[[C0]]] +// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]] +// CHECK-DAG: %[[S1:.+]] = tensor.from_elements %[[C2]], %[[C3]], %[[D1]] +// CHECK-DAG: %[[D1_OUT:.+]] = tensor.extract %[[S1]][%[[C2]]] +// CHECK: return %[[D0_OUT]], %[[C5]], %[[C2]], %[[C3]], %[[D1_OUT]] diff --git a/mlir/test/Transforms/test-canonicalize.mlir 
b/mlir/test/Transforms/test-canonicalize.mlir --- a/mlir/test/Transforms/test-canonicalize.mlir +++ b/mlir/test/Transforms/test-canonicalize.mlir @@ -82,30 +82,6 @@ return %0 : i32 } -// CHECK-LABEL: func @result_shape_per_dim -// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>, %[[ARG_1:[a-z0-9]*]]: tensor) -func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor) - -> (index, index, index, index, index) { - // CHECK-DAG: %[[C0:.+]] = constant 0 : index - // CHECK-DAG: %[[C2:.+]] = constant 2 : index - // CHECK-DAG: %[[C3:.+]] = constant 3 : index - // CHECK-DAG: %[[C5:.+]] = constant 5 : index - %c0 = constant 0 : index - %c1 = constant 1 : index - %c2 = constant 2 : index - %0:2 = "test.op_with_result_shape_per_dim_interface"(%arg0, %arg1) - : (tensor<2x3x?xf32>, tensor) -> (tensor, tensor<2x3x?xf32>) - %1 = memref.dim %0#0, %c0 : tensor - %2 = memref.dim %0#0, %c1 : tensor - %3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32> - %4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32> - %5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32> - // CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]] - // CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]] - // CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]] - return %1, %2, %3, %4, %5 : index, index, index, index, index -} - // CHECK-LABEL: test_dialect_canonicalizer func @test_dialect_canonicalizer() -> (i32) { %0 = "test.dialect_canonicalizable"() : () -> (i32) diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -63,6 +63,7 @@ MLIRReduce MLIRStandard MLIRStandardOpsTransforms + MLIRTensor MLIRTransformUtils MLIRTransforms ) diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -13,6 +13,7 @@ #include 
"mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/PatternMatch.h" @@ -788,22 +789,75 @@ return success(); } +LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( + OpBuilder &builder, ValueRange operands, + llvm::SmallVectorImpl &shapes) { + Location loc = getLoc(); + shapes.reserve(operands.size()); + for (Value operand : llvm::reverse(operands)) { + auto currShape = llvm::to_vector<4>(llvm::map_range( + llvm::seq( + 0, operand.getType().cast().getRank()), + [&](int64_t dim) -> Value { + return builder.createOrFold(loc, operand, dim); + })); + shapes.push_back(builder.create( + getLoc(), builder.getIndexType(), currShape)); + } + return success(); +} + LogicalResult OpWithResultShapePerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim( OpBuilder &builder, llvm::SmallVectorImpl> &shapes) { - SmallVector operand1Shape, operand2Shape; Location loc = getLoc(); - for (auto i : - llvm::seq(0, operand1().getType().cast().getRank())) { - operand1Shape.push_back(builder.create(loc, operand1(), i)); + shapes.reserve(getNumOperands()); + for (Value operand : llvm::reverse(getOperands())) { + auto currShape = llvm::to_vector<4>(llvm::map_range( + llvm::seq( + 0, operand.getType().cast().getRank()), + [&](int64_t dim) -> Value { + return builder.createOrFold(loc, operand, dim); + })); + shapes.emplace_back(std::move(currShape)); } - for (auto i : - llvm::seq(0, operand2().getType().cast().getRank())) { - operand2Shape.push_back(builder.create(loc, operand2(), i)); + return success(); +} + +LogicalResult OpWithResultShapeAndPerDimInterfaceOp::reifyReturnTypeShapes( + OpBuilder &builder, ValueRange operands, + llvm::SmallVectorImpl &shapes) { + Location loc = getLoc(); + shapes.reserve(operands.size()); + for (Value operand : llvm::reverse(operands)) { + auto 
currShape = llvm::to_vector<4>(llvm::map_range( + llvm::seq( + 0, operand.getType().cast().getRank()), + [&](int64_t dim) -> Value { + return builder.createOrFold(loc, operand, dim); + })); + shapes.push_back(builder.create( + getLoc(), builder.getIndexType(), currShape)); + } + return success(); +} + +LogicalResult +OpWithResultShapeAndPerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim( + OpBuilder &builder, + llvm::SmallVectorImpl> &shapes) { + Location loc = getLoc(); + shapes.reserve(getNumOperands()); + for (Value operand : llvm::reverse(getOperands())) { + auto currShape = llvm::to_vector<4>(llvm::map_range( + llvm::seq( + 0, operand.getType().cast().getRank()), + [&](int64_t dim) -> Value { + return builder.createOrFold(loc, operand, dim); + })); + shapes.emplace_back(std::move(currShape)); } - shapes.emplace_back(std::move(operand2Shape)); - shapes.emplace_back(std::move(operand1Shape)); return success(); } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -571,9 +571,25 @@ let results = (outs AnyTensor); } -def OpWithResultShapePerDimInterfaceOp : TEST_Op<"op_with_result_shape_per_dim_interface", +def OpWithResultShapeInterfaceOp : TEST_Op<"op_with_result_shape_interface", [DeclareOpInterfaceMethods]> { + ["reifyReturnTypeShapes"]>]> { + let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2); + let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2); +} + +def OpWithResultShapePerDimInterfaceOp : + TEST_Op<"op_with_result_shape_per_dim_interface", + [DeclareOpInterfaceMethods]> { + let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2); + let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2); +} + +def OpWithResultShapeAndPerDimInterfaceOp : + TEST_Op<"op_with_result_shape_and_per_dim_interface", + [DeclareOpInterfaceMethods]> { let arguments = (ins 
AnyRankedTensor:$operand1, AnyRankedTensor:$operand2); let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2); }