diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -174,13 +174,35 @@
   auto lowPad = padOp.getMixedLowPad();
   auto highPad = padOp.getMixedHighPad();
   SmallVector<Value> shapes;
+
+  // Try to see if the source op implements the same interface. If so, we can
+  // go a step further and avoid generating dim ops, which helps compose
+  // affine expressions during result shape simplification and makes more
+  // dimensions static.
+  ReifiedRankedShapedTypeDims sourceResultShapes;
+  SmallVector<Value> sourceShape;
+  auto sourceIfxOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
+      padOp.source().getDefiningOp());
+  if (sourceIfxOp &&
+      succeeded(sourceIfxOp.reifyResultShapes(b, sourceResultShapes))) {
+    int resultIndex = padOp.source().cast<OpResult>().getResultNumber();
+    sourceShape = sourceResultShapes[resultIndex];
+  }
+
   for (auto dim : llvm::seq<int64_t>(0, padOp.getSourceType().getRank())) {
     // Shape along each dimension is source dim + low pad + high pad.
     SmallVector<Value> mapOperands;
-    mapOperands.push_back(
-        b.createOrFold<tensor::DimOp>(loc, padOp.source(), dim));
+
+    Value sourceDim;
+    if (sourceShape.empty())
+      sourceDim = b.createOrFold<tensor::DimOp>(loc, padOp.source(), dim);
+    else
+      sourceDim = sourceShape[dim];
+    mapOperands.push_back(sourceDim);
+
     AffineExpr expr = b.getAffineDimExpr(0);
     unsigned numSymbols = 0;
+
     auto addOpFoldResult = [&](OpFoldResult valueOrAttr) {
       if (Value v = valueOrAttr.dyn_cast<Value>()) {
         expr = expr + b.getAffineSymbolExpr(numSymbols++);
@@ -193,8 +215,29 @@
     };
     addOpFoldResult(lowPad[dim]);
     addOpFoldResult(highPad[dim]);
-    shapes.push_back(applyMapToValues(
-        b, loc, AffineMap::get(1, numSymbols, expr), mapOperands)[0]);
+
+    AffineMap map = AffineMap::get(1, numSymbols, expr);
+    fullyComposeAffineMapAndOperands(&map, &mapOperands);
+    canonicalizeMapAndOperands(&map, &mapOperands);
+
+    // Handle the case where we have both dimensions and symbols and they map
+    // to the same value, e.g.:
+    //   affine_map<(d0)[s0] -> (d0 - s0 + 4)>(%v)[%v].
+    // Due to the restrictions on dimensions and symbols, the above won't
+    // simplify. Try to convert dimensions to symbols in such cases.
+    if (llvm::is_splat(mapOperands)) {
+      int numDims = map.getNumDims();
+      int numSyms = map.getNumSymbols();
+      DenseMap<AffineExpr, AffineExpr> dimToSymMap;
+      for (int i = 0; i < numDims; ++i) {
+        dimToSymMap[b.getAffineDimExpr(i)] =
+            b.getAffineSymbolExpr(numSyms + i);
+      }
+      map = map.replace(dimToSymMap, /*numResultDims=*/0,
+                        /*numResultSyms=*/numDims + numSyms);
+    }
+
+    shapes.push_back(applyMapToValues(b, loc, map, mapOperands)[0]);
   }
   reifiedReturnShapes.emplace_back(std::move(shapes));
   return success();
diff --git a/mlir/test/Dialect/Linalg/pad_fusion.mlir b/mlir/test/Dialect/Linalg/pad_fusion.mlir
--- a/mlir/test/Dialect/Linalg/pad_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/pad_fusion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-linalg-pad-fusion -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -test-linalg-pad-fusion -cse %s | FileCheck %s
 
 func @dynamic_pad_fusion(%arg0 : tensor<?x?xf32>, %arg1 : index, %arg2 : index,
     %arg3 : index, %arg4 : index, %arg5 : f32) -> tensor<?x?xf32> {
@@ -33,9 +33,9 @@
 // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
 // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
 // CHECK-DAG: %[[SOURCE:.+]] = linalg.generic
-// CHECK-DAG: %[[SOURCE_D0:.+]] = tensor.dim %[[SOURCE]], %[[C0]]
+// CHECK-DAG: %[[SOURCE_D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
 // CHECK-DAG: %[[TARGET_D0:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG3]], %[[SOURCE_D0]]]
-// CHECK-DAG: %[[SOURCE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
+// CHECK-DAG: %[[SOURCE_D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
 // CHECK-DAG: %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[SOURCE_D1]]]
 // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[TARGET_D0]], %[[TARGET_D1]]]
 // CHECK: %[[FILL:.+]] = linalg.fill(%[[ARG5]], %[[INIT]])
@@ -79,7 +79,7 @@
 // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
 // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
 // CHECK-DAG: %[[SOURCE:.+]] = linalg.generic
-// CHECK-DAG: %[[SOURCE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
+// CHECK-DAG: %[[SOURCE_D1:.+]] = tensor.dim %[[ARG0]], %[[C0]]
 // CHECK-DAG: %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]], %[[SOURCE_D1]]]
 // CHECK: %[[INIT:.+]] = linalg.init_tensor [49, %[[TARGET_D1]]]
 // CHECK: %[[FILL:.+]] = linalg.fill(%[[ARG3]], %[[INIT]])
diff --git a/mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir
--- a/mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir
@@ -142,3 +142,74 @@
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
 // CHECK: return %[[ARG1]], %[[ARG2]]
+
+// -----
+
+func @pad_only_high_pad(%tensor: tensor<1x224x224x3xf32>, %arg0: index, %arg1: index) -> (index, index) {
+  %f0 = arith.constant 0.0 : f32
+  %0 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0)
+  %1 = affine.min affine_map<(d0) -> (d0 * 2 + 3, 224)>(%arg0)
+  %2 = affine.apply affine_map<(d0, d1) -> (d0 - d1 * 2)>(%1, %arg0)
+  %3 = affine.apply affine_map<(d0, d1) -> (-d0 + d1 * 2 + 3)>(%1, %arg0)
+  %4 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1)
+  %5 = affine.min affine_map<(d0) -> (d0 * 2 + 9, 224)>(%arg1)
+  %6 = affine.apply affine_map<(d0, d1) -> (d0 - d1 * 2)>(%5, %arg1)
+  %7 = affine.apply affine_map<(d0, d1) -> (-d0 + d1 * 2 + 9)>(%5, %arg1)
+  %8 = tensor.extract_slice %tensor[0, %0, %4, 0][1, %2, %6, 3][1, 1, 1, 1] : tensor<1x224x224x3xf32> to tensor<1x?x?x3xf32>
+
+  // Dim#1: %2 (source) + %3 (high pad) = (%1 - %arg0 * 2) + (-%1 + %arg0 * 2 + 3) = 3
+  // Dim#2: %6 (source) + %7 (high pad) = (%5 - %arg1 * 2) + (-%5 + %arg1 * 2 + 9) = 9
+  %pad = tensor.pad %8 low[0, 0, 0, 0] high[0, %3, %7, 0] {
+  ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %f0 : f32
+  } : tensor<1x?x?x3xf32> to tensor<1x?x?x3xf32>
+
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %dim1 = tensor.dim %pad, %c1 : tensor<1x?x?x3xf32>
+  %dim2 = tensor.dim %pad, %c2 : tensor<1x?x?x3xf32>
+  return %dim1, %dim2 : index, index
+}
+
+// CHECK-LABEL: func @pad_only_high_pad
+// CHECK: %[[C3:.+]] = arith.constant 3 : index
+// CHECK: %[[C9:.+]] = arith.constant 9 : index
+// CHECK: return %[[C3]], %[[C9]]
+
+// -----
+
+func @pad_both_low_and_high_pad(%tensor: tensor<1x56x56x144xf32>, %arg0: index, %arg1: index, %arg2: index) -> (index, index) {
+  %f0 = arith.constant 0.0 : f32
+  %0 = affine.max affine_map<(d0) -> (0, -d0 + 1)>(%arg0)
+  %1 = affine.max affine_map<(d0) -> (d0 - 1, 0)>(%arg0)
+  %2 = affine.min affine_map<(d0) -> (d0, 56)>(%1)
+  %3 = affine.max affine_map<(d0) -> (d0 + 3, 0)>(%arg0)
+  %4 = affine.min affine_map<(d0) -> (d0, 56)>(%3)
+  %5 = affine.apply affine_map<(d0, d1) -> (d0 - d1)>(%4, %2)
+  %6 = affine.apply affine_map<(d0, d1, d2) -> (-d0 - d1 + d2 + 4)>(%0, %4, %2)
+  %7 = affine.max affine_map<(d0) -> (0, -d0 + 1)>(%arg1)
+  %8 = affine.max affine_map<(d0) -> (d0 - 1, 0)>(%arg1)
+  %9 = affine.min affine_map<(d0) -> (d0, 56)>(%8)
+  %10 = affine.max affine_map<(d0) -> (d0 + 3, 0)>(%arg1)
+  %11 = affine.min affine_map<(d0) -> (d0, 56)>(%10)
+  %12 = affine.apply affine_map<(d0, d1) -> (d0 - d1)>(%11, %9)
+  %13 = affine.apply affine_map<(d0, d1, d2) -> (-d0 - d1 + d2 + 4)>(%7, %11, %9)
+  %14 = tensor.extract_slice %tensor[0, %2, %9, %arg2][1, %5, %12, 16][1, 1, 1, 1] : tensor<1x56x56x144xf32> to tensor<1x?x?x16xf32>
+
+  // Dim#1: %0 (low pad) + %5 (source) + %6 (high pad) = %0 + (%4 - %2) + (-%0 - %4 + %2 + 4) = 4
+  // Dim#2: %7 (low pad) + %12 (source) + %13 (high pad) = %7 + (%11 - %9) + (-%7 - %11 + %9 + 4) = 4
+  %pad = tensor.pad %14 low[0, %0, %7, 0] high[0, %6, %13, 0] {
+  ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index): // no predecessors
+    tensor.yield %f0 : f32
+  } : tensor<1x?x?x16xf32> to tensor<1x?x?x16xf32>
+
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %dim1 = tensor.dim %pad, %c1 : tensor<1x?x?x16xf32>
+  %dim2 = tensor.dim %pad, %c2 : tensor<1x?x?x16xf32>
+  return %dim1, %dim2 : index, index
+}
+
+// CHECK-LABEL: func @pad_both_low_and_high_pad
+// CHECK: %[[C4:.+]] = arith.constant 4 : index
+// CHECK: return %[[C4]], %[[C4]]
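
Note: the dimension-to-symbol substitution in the first hunk is the crux of the change. MLIR's affine simplifier keeps dimension and symbol positions distinct even when both are bound to the same SSA value, so a map like (d0)[s0] -> (d0 - s0 + 4) applied to (%v)[%v] never folds; once every position lives in the symbol space, canonicalizeMapAndOperands() can deduplicate the identical operands and the map folds to a constant. The standalone sketch below illustrates just that substitution in isolation. It is not part of the patch, and demoteDimsToSymbols is a hypothetical helper name, though AffineMap::replace, getAffineDimExpr, and getAffineSymbolExpr are the same MLIR APIs the diff itself uses.

#include "llvm/ADT/DenseMap.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"

using namespace mlir;

/// Rewrites every dimension expression d_i in `map` into the symbol
/// expression s_{numSyms + i}, yielding a map with zero dimensions. All
/// operands then occupy the symbol space, so a later call to
/// canonicalizeMapAndOperands() can merge duplicate operands and fold
/// expressions such as s1 - s0 + 4 when both symbols are bound to the
/// same value.
static AffineMap demoteDimsToSymbols(AffineMap map) {
  MLIRContext *ctx = map.getContext();
  unsigned numDims = map.getNumDims();
  unsigned numSyms = map.getNumSymbols();
  llvm::DenseMap<AffineExpr, AffineExpr> repl;
  for (unsigned i = 0; i < numDims; ++i)
    repl[getAffineDimExpr(i, ctx)] = getAffineSymbolExpr(numSyms + i, ctx);
  return map.replace(repl, /*numResultDims=*/0,
                     /*numResultSyms=*/numDims + numSyms);
}

For example, demoting (d0)[s0] -> (d0 - s0 + 4) yields ()[s0, s1] -> (s1 - s0 + 4); with both symbol operands bound to the same %v, operand canonicalization collapses them into one symbol and the expression simplifies to the constant 4, which is how the two new tests obtain fully static result dimensions.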