diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -96,14 +96,6 @@
   // Fail if `currOpOperand` is not defined by an ExtractSliceOp or EmptyOp.
   auto sliceOp = currOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
   auto emptyOp = currOpOperand->get().getDefiningOp<tensor::EmptyOp>();
-  if (!sliceOp && !emptyOp) {
-    // TODO: may want to add support for going through loop iter args.
-    // This is not strictly necessary as we can pad before hoisting but it would
-    // make the system more resilient to minor transformation reordering.
-    LLVM_DEBUG(DBGS() << "--not defined by an extractSlice or emptyOp\n");
-    return rewriter.notifyMatchFailure(
-        opToPad, "not defined by an extractSlice or emptyOp");
-  }
 
   llvm::SmallBitVector droppedDims;
   SmallVector<OpFoldResult> mixedSizes;
@@ -111,10 +103,19 @@
     // Compute the dropped dimensions if `sliceOp` is rank-reducing.
     droppedDims = sliceOp.getDroppedDims();
     mixedSizes = sliceOp.getMixedSizes();
-  }
-  if (emptyOp) {
+  } else if (emptyOp) {
     mixedSizes = emptyOp.getMixedSizes();
     droppedDims.resize(mixedSizes.size());
+  } else if (hasStaticShape) {
+    mixedSizes = getAsIndexOpFoldResult(rewriter.getContext(), shape);
+    droppedDims.resize(mixedSizes.size());
+  } else {
+    // TODO: may want to add support for going through loop iter args.
+    // This is not strictly necessary as we can pad before hoisting but it would
+    // make the system more resilient to minor transformation reordering.
+    LLVM_DEBUG(DBGS() << "--not defined by an extractSlice or emptyOp\n");
+    return rewriter.notifyMatchFailure(
+        opToPad, "not defined by an extractSlice or emptyOp");
   }
   LLVM_DEBUG(llvm::interleaveComma(mixedSizes, DBGS() << "--mixedSizes: ");
              llvm::dbgs() << "\n");
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -127,6 +127,7 @@
 
 // -----
 
+// CHECK-LABEL: @pad(
 func.func @pad(%arg0: tensor<24x12xf32>,
                %arg1: tensor<12x25xf32>,
                %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
@@ -147,3 +148,63 @@
     pack_paddings=[1, 1, 0]
   }
 }
+
+// -----
+
+// Check that the padding can be applied even when the output argument of the
+// linalg op is not produced by an empty op or an extract_slice op.
+
+// CHECK-DAG: #[[$MAP_MIN:.*]] = affine_map<(d0) -> (-d0 + 2044, 16)>
+// CHECK-DAG: #[[$MAP_C0:.*]] = affine_map<() -> (0)>
+// CHECK-DAG: #[[$MAP_TO_16:.*]] = affine_map<(d0) -> (-d0 + 16)>
+// CHECK-LABEL: @outs_not_produced_by_empty_or_extract_slice(
+// CHECK-SAME: %[[A:[^: ]*]]: tensor<128x2044xf32>,
+// CHECK-SAME: %[[B:[^: ]*]]: tensor<2044x128xf32>)
+func.func @outs_not_produced_by_empty_or_extract_slice(%a : tensor<128x2044xf32>, %b : tensor<2044x128xf32>) -> tensor<128x128xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<128x128xf32>
+  %9 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c2044 = arith.constant 2044 : index
+  // CHECK: scf.for %[[ARG3:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %{{.*}})
+  %10 = scf.for %arg3 = %c0 to %c2044 step %c16 iter_args(%arg4 = %9) -> (tensor<128x128xf32>) {
+    // CHECK: %[[MIN:.*]] = affine.min #[[$MAP_MIN]](%[[ARG3]])
+    %11 = affine.min affine_map<(d0) -> (-d0 + 2044, 16)>(%arg3)
+    // CHECK: %[[A_SLICE:.*]] = tensor.extract_slice %[[A]]
+    // CHECK: %[[B_SLICE:.*]] = tensor.extract_slice %[[B]]
+    %extracted_slice_2 = tensor.extract_slice %a[0, %arg3] [128, %11] [1, 1] : tensor<128x2044xf32> to tensor<128x?xf32>
+    %extracted_slice_3 = tensor.extract_slice %b[%arg3, 0] [%11, 128] [1, 1] : tensor<2044x128xf32> to tensor<?x128xf32>
+    // CHECK-DAG: %[[CST:.*]] = arith.constant 0.
+    // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+
+    // CHECK-DAG: %[[ZERO:.*]] = affine.apply #[[$MAP_C0]]()
+    // CHECK-DAG: %[[TO_16:.*]] = affine.apply #[[$MAP_TO_16]](%[[MIN]])
+    // CHECK: %[[PADDED_A_SLICE:.*]] = tensor.pad %[[A_SLICE]] nofold low[%[[C0]], %[[C0]]] high[%[[ZERO]], %[[TO_16]]]
+    // CHECK: tensor.yield %[[CST]]
+    // CHECK: %[[PADDED_B_SLICE:.*]] = tensor.pad %[[B_SLICE]] nofold
+    // The output shape is already padded, so the padding actually shouldn't
+    // add anything to the upper bound.
+    // CHECK: %[[ZERO0:.*]] = affine.apply #[[$MAP_C0]]()
+    // CHECK: %[[ZERO1:.*]] = affine.apply #[[$MAP_C0]]()
+    // CHECK: %[[PADDED_ARG4:.*]] = tensor.pad %[[ARG4]] nofold low[{{.*}}] high[%[[ZERO0]], %[[ZERO1]]]
+
+    // CHECK: %[[T5:.*]] = linalg.matmul
+    // CHECK-SAME: ins(%[[PADDED_A_SLICE]], %[[PADDED_B_SLICE]] : tensor<128x16xf32>, tensor<16x128xf32>)
+    // CHECK-SAME: outs(%[[PADDED_ARG4]] : tensor<128x128xf32>)
+    %res = linalg.matmul ins(%extracted_slice_2, %extracted_slice_3 : tensor<128x?xf32>, tensor<?x128xf32>) outs(%arg4 : tensor<128x128xf32>) -> tensor<128x128xf32>
+    scf.yield %res : tensor<128x128xf32>
+  }
+  return %10 : tensor<128x128xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1 = transform.structured.pad %0 {
+    padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
+    padding_dimensions=[0, 1, 2],
+    pack_paddings=[1, 1, 1]
+  }
+}