diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -837,8 +837,10 @@
     } else {
       Value check = rewriter.create<CmpIOp>(
           loc, CmpIPredicate::eq, newLength, zero);
-      dynHasZeroLenCond = dynHasZeroLenCond
-          ? rewriter.create<OrOp>(loc, check, dynHasZeroLenCond) : check;
+      dynHasZeroLenCond =
+          dynHasZeroLenCond
+              ? rewriter.create<OrOp>(loc, check, dynHasZeroLenCond)
+              : check;
     }
 
     // The amount of high padding is simply the number of elements remaining,
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -61,8 +61,8 @@
 // CHECK: %[[sizeC1:.*]] = affine.min #[[BOUND4_MAP]](%[[K]])[%[[dC1]]]
 // CHECK: %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[K]]] [%[[sizeC0]], %[[sizeC1]]] [1, 1] : tensor to tensor
 // CHECK: %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor, tensor) outs(%[[stC]] : tensor) -> tensor
-// CHECK: %[[CAST:.*]] = tensor.cast %[[stD]] : tensor to tensor
-// CHECK-NEXT: %[[stG:.*]] = linalg.matmul ins(%[[CAST]], %[[stB1]] : tensor, tensor<4x3xf32>) outs(%[[stF]] : tensor<2x3xf32>) -> tensor<2x3xf32>
+// CHECK: %[[CAST:.*]] = tensor.cast %[[stD]] : tensor to tensor<2x4xf32>
+// CHECK-NEXT: %[[stG:.*]] = linalg.matmul ins(%[[CAST]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) outs(%[[stF]] : tensor<2x3xf32>) -> tensor<2x3xf32>
 // CHECK-NEXT: subtensor_insert %[[stG]] into %[[RES]][%[[I]], %[[J]]]
 
 // -----
@@ -279,3 +279,66 @@
 // CHECK-SAME: outs(%[[ST_ARG]] : tensor)
 // CHECK: subtensor_insert %[[ST_ADD]] into %[[ARG]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
 // CHECK-SAME: [%[[SIZE_ELEM_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC]]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: func @pad_generic_static
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C16:.*]] = constant 16 : index
+// CHECK-DAG: %[[C32:.*]] = constant 32 : index
+// CHECK-DAG: %[[C64:.*]] = constant 64 : index
+// CHECK-DAG: %[[C128:.*]] = constant 128 : index
+// CHECK: scf.for %{{.*}} = %[[C0]] to %[[C64]] step %[[C16]]
+// CHECK: %[[CMPI1:.*]] = cmpi eq
+// CHECK: scf.for %{{.*}} = %[[C0]] to %[[C128]] step %[[C32]]
+// CHECK: %[[CMPI2:.*]] = cmpi eq
+// CHECK: %[[HASZERO:.*]] = or %[[CMPI2]], %[[CMPI1]] : i1
+// CHECK: scf.if %[[HASZERO]]
+// CHECK: tensor.generate
+// CHECK: else
+// CHECK: subtensor
+// CHECK: linalg.pad_tensor
+// CHECK: tensor.cast
+// CHECK: subtensor
+// CHECK: subtensor
+// CHECK: linalg.generic
+// CHECK: subtensor_insert
+func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tensor<64x128xf32>) -> tensor<64x128xf32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c16 = constant 16 : index
+  %c32 = constant 32 : index
+  %zero = constant 0.0 : f32
+
+  %d0 = memref.dim %large_input, %c0 : tensor<64x128xf32>
+  %d1 = memref.dim %large_input, %c1 : tensor<64x128xf32>
+
+  %pad = linalg.pad_tensor %small_input low[4, 60] high[2, 67] {
+    ^bb0(%arg0: index, %arg1: index):
+      linalg.yield %zero : f32
+  } : tensor<58x1xf32> to tensor<64x128xf32>
+
+  %fill = linalg.fill(%large_input, %zero) : tensor<64x128xf32>, f32 -> tensor<64x128xf32>
+
+  %for0 = scf.for %iv0 = %c0 to %d0 step %c16 iter_args(%arg0 = %fill) -> tensor<64x128xf32> {
+    %for1 = scf.for %iv1 = %c0 to %d1 step %c32 iter_args(%arg1 = %arg0) -> tensor<64x128xf32> {
+      %0 = subtensor %pad[%iv0, %iv1][16, 32][1, 1] : tensor<64x128xf32> to tensor<16x32xf32>
+      %1 = subtensor %large_input[%iv0, %iv1][16, 32][1, 1] : tensor<64x128xf32> to tensor<16x32xf32>
+      %2 = subtensor %arg1[%iv0, %iv1][16, 32][1, 1] : tensor<64x128xf32> to tensor<16x32xf32>
+
+      %add = linalg.generic
+        {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
+        ins(%0, %1 : tensor<16x32xf32>, tensor<16x32xf32>) outs(%2 : tensor<16x32xf32>) {
+      ^bb0(%arg4: f32, %arg5: f32, %arg6: f32):
+        %result = addf %arg4, %arg5 : f32
+        linalg.yield %result : f32
+      } -> tensor<16x32xf32>
+
+      %insert = subtensor_insert %add into %arg1[%iv0, %iv1] [16, 32] [1, 1] : tensor<16x32xf32> into tensor<64x128xf32>
+      scf.yield %insert : tensor<64x128xf32>
+    }
+    scf.yield %for1 : tensor<64x128xf32>
+  }
+  return %for0 : tensor<64x128xf32>
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -235,9 +235,10 @@
     MLIRContext *context = &getContext();
     RewritePatternSet patterns =
         linalg::getLinalgTilingCanonicalizationPatterns(context);
-    patterns.add(context);
+    patterns.add(context);
     FrozenRewritePatternSet frozenPatterns(std::move(patterns));
-    while (succeeded(fuseLinalgOpsGreedily(getFunction()))) {
+    do {
       (void)applyPatternsAndFoldGreedily(getFunction(), frozenPatterns);
       PassManager pm(context);
       pm.addPass(createLoopInvariantCodeMotionPass());
@@ -246,7 +247,7 @@
       LogicalResult res = pm.run(getFunction()->getParentOfType<ModuleOp>());
       if (failed(res))
         this->signalPassFailure();
-    }
+    } while (succeeded(fuseLinalgOpsGreedily(getFunction())));
   }
 };