diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -823,17 +823,19 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      DeclareOpInterfaceMethods<TransformOpInterface>,
     ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
     Pads the operations pointed to by the target handle using the options
-    provides as operation attributes.
+    provided as operation attributes. The operation returns a handle to the
+    padded operation and to the padding operation ("tensor.pad").
 
     #### Return modes
 
    This operation ignores non-Linalg ops and drops them in the return.
    This operation may produce a definiteFailure if the padding fails for any
    reason.
+
    If all the operations referred to by the `target` handle pad
    properly, the transform succeeds. Otherwise the transform silently fails.
    The return handle points to only the subset of successfully produced
@@ -849,11 +851,11 @@
          DefaultValuedAttr<
            TypedArrayAttrBase<I64ArrayAttr, "array of arrays of i64">,
            "{}">:$transpose_paddings);
-  let results = (outs TransformHandleTypeInterface:$transformed);
+  let results = (outs TransformHandleTypeInterface:$padded,
+                      TransformHandleTypeInterface:$pad);
 
   let assemblyFormat =
-    "$target attr-dict `:` "
-    "custom<SemiFunctionType>(type($target), type($transformed))";
+    "$target attr-dict `:` functional-type(operands, results)";
   let hasVerifier = 1;
 
   let extraClassDeclaration = [{
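With the second result and the switch to `functional-type`, uses of the op spell out the full functional type. A before/after sketch of the surface syntax (attribute values borrowed from the updated tests further down; handle names are illustrative):

```mlir
// Before: a single result handle to the padded op.
%1 = transform.structured.pad %0 {
  padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32],
  padding_dimensions = [0, 1, 2]
} : (!transform.any_op) -> !transform.any_op

// After: a second result handle that maps to the created tensor.pad ops.
%padded, %pad = transform.structured.pad %0 {
  padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32],
  padding_dimensions = [0, 1, 2]
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
```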
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -364,19 +364,24 @@
 void peelLoops(RewriterBase &rewriter, ArrayRef<scf::ForOp> loops);
 
 /// Pad the iterator dimensions `paddingDimensions` of all `opToPad` operands
-/// to a static bounding box. `padToMultipleOf` indicates that each padding
-/// dimension should be padded to the specified multiple. If the derived padding
-/// sizes should not be rounded up to any multiple, use "1". Use `paddingValues`
-/// and `packPaddings` to set padding value and nofold attribute of the created
-/// tensor::PadOps, respectively. Update `paddedOp` to the cloned operation with
-/// statically shaped `paddingDimensions` and return the extracted dynamically
-/// shaped results. If padding fails, return failure.
-FailureOr<SmallVector<Value>>
-rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
-                  ArrayRef<int64_t> paddingDimensions,
-                  ArrayRef<int64_t> padToMultipleOf,
-                  ArrayRef<Attribute> paddingValues,
-                  ArrayRef<bool> packPaddings, LinalgOp &paddedOp);
+/// to a static bounding box. The original `opToPad` is cloned and operates on
+/// the padded tensors.
+///
+/// * `padToMultipleOf` indicates that each padding dimension should be padded
+///   to the specified multiple. If the derived padding sizes should not be
+///   rounded up to any multiple, use "1".
+/// * Use `paddingValues` and `packPaddings` to set padding value and nofold
+///   attribute of the created tensor::PadOps, respectively.
+/// * The unpadded results (extracted slice of the cloned operation) are
+///   returned via `replacements`.
+/// * The tensor::PadOps are returned via `padOps`.
+LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
+                                ArrayRef<int64_t> paddingDimensions,
+                                ArrayRef<int64_t> padToMultipleOf,
+                                ArrayRef<Attribute> paddingValues,
+                                ArrayRef<bool> packPaddings, LinalgOp &paddedOp,
+                                SmallVector<Value> &replacements,
+                                SmallVector<tensor::PadOp> &padOps);
 
 namespace detail {
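To make the new out-parameters of `rewriteAsPaddedOp` concrete, here is a hand-written sketch (not compiler output; the shapes, the bound of 7, and the `nofold` flag are illustrative assumptions) of the IR it produces for a matmul with one dynamic dimension:

```mlir
%cst = arith.constant 0.000000e+00 : f32
// One tensor.pad per padded operand; these ops are collected in `padOps`.
%a_pad = tensor.pad %a nofold low[0, 0] high[0, %h] {
^bb0(%i: index, %j: index):
  tensor.yield %cst : f32
} : tensor<4x?xf32> to tensor<4x7xf32>
// (%b_pad is created analogously.)
// The clone of `opToPad` operating on the padded tensors: `paddedOp`.
%r = linalg.matmul ins(%a_pad, %b_pad : tensor<4x7xf32>, tensor<7x5xf32>)
                   outs(%c : tensor<4x5xf32>) -> tensor<4x5xf32>
// Each result is sliced back to its original size and copied into the
// original init operand; these copies are what lands in `replacements`.
%s = tensor.extract_slice %r[0, 0] [4, 5] [1, 1]
    : tensor<4x5xf32> to tensor<4x5xf32>
%copy = linalg.copy ins(%s : tensor<4x5xf32>)
                    outs(%c : tensor<4x5xf32>) -> tensor<4x5xf32>
```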
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1566,76 +1566,93 @@
 //===---------------------------------------------------------------------===//
 
 DiagnosedSilenceableFailure
-transform::PadOp::applyToOne(transform::TransformRewriter &rewriter,
-                             LinalgOp target,
-                             transform::ApplyToEachResultList &results,
-                             transform::TransformState &state) {
-  // Convert the integer packing flags to booleans.
-  SmallVector<bool> packPaddings;
-  for (int64_t packPadding : extractFromI64ArrayAttr(getPackPaddings()))
-    packPaddings.push_back(static_cast<bool>(packPadding));
-
-  // Convert the padding values to attributes.
-  SmallVector<Attribute> paddingValues;
-  for (auto const &it :
-       llvm::zip(getPaddingValues(), target->getOperandTypes())) {
-    auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
-    if (!attr) {
-      emitOpError("expects padding values to be typed attributes");
-      return DiagnosedSilenceableFailure::definiteFailure();
+transform::PadOp::apply(transform::TransformRewriter &rewriter,
+                        transform::TransformResults &results,
+                        transform::TransformState &state) {
+  SmallVector<Operation *> paddedOps, padOps;
+
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    auto linalgTarget = dyn_cast<LinalgOp>(target);
+    if (!linalgTarget) {
+      auto diag = emitSilenceableError() << "expected LinalgOp target";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
     }
-    Type elementType = getElementTypeOrSelf(std::get<1>(it));
-    // Try to parse string attributes to obtain an attribute of element type.
-    if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
-      auto parsedAttr = dyn_cast_if_present<TypedAttr>(
-          parseAttribute(stringAttr, getContext(), elementType,
-                         /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
-      if (!parsedAttr || parsedAttr.getType() != elementType) {
-        auto diag = this->emitOpError("expects a padding that parses to ")
-                    << elementType << ", got " << std::get<0>(it);
-        diag.attachNote(target.getLoc()) << "when applied to this op";
+
+    // Convert the integer packing flags to booleans.
+    SmallVector<bool> packPaddings;
+    for (int64_t packPadding : extractFromI64ArrayAttr(getPackPaddings()))
+      packPaddings.push_back(static_cast<bool>(packPadding));
+
+    // Convert the padding values to attributes.
+    SmallVector<Attribute> paddingValues;
+    for (auto const &it :
+         llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
+      auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
+      if (!attr) {
+        emitOpError("expects padding values to be typed attributes");
        return DiagnosedSilenceableFailure::definiteFailure();
      }
-      paddingValues.push_back(parsedAttr);
-      continue;
-    }
-    // Otherwise, add the attribute directly.
-    if (attr.getType() != elementType) {
-      auto diag = this->emitOpError("expects a padding value of type ")
-                  << elementType << ", got " << attr;
-      diag.attachNote(target.getLoc()) << "when applied to this op";
-      return DiagnosedSilenceableFailure::definiteFailure();
+      Type elementType = getElementTypeOrSelf(std::get<1>(it));
+      // Try to parse string attributes to obtain an attribute of element type.
+      if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
+        auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
+            stringAttr, getContext(), elementType,
+            /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
+        if (!parsedAttr || parsedAttr.getType() != elementType) {
+          auto diag = this->emitOpError("expects a padding that parses to ")
+                      << elementType << ", got " << std::get<0>(it);
+          diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
+          return DiagnosedSilenceableFailure::definiteFailure();
+        }
+        paddingValues.push_back(parsedAttr);
+        continue;
+      }
+      // Otherwise, add the attribute directly.
+      if (attr.getType() != elementType) {
+        auto diag = this->emitOpError("expects a padding value of type ")
+                    << elementType << ", got " << attr;
+        diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
+        return DiagnosedSilenceableFailure::definiteFailure();
+      }
+      paddingValues.push_back(attr);
     }
-    paddingValues.push_back(attr);
-  }
 
-  // Extract the transpose vectors.
-  SmallVector<SmallVector<int64_t>> transposePaddings;
-  for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
-    transposePaddings.push_back(
-        extractFromI64ArrayAttr(cast<ArrayAttr>(transposeVector)));
+    // Extract the transpose vectors.
+    SmallVector<SmallVector<int64_t>> transposePaddings;
+    for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
+      transposePaddings.push_back(
+          extractFromI64ArrayAttr(cast<ArrayAttr>(transposeVector)));
+
+    LinalgOp paddedOp;
+    SmallVector<int64_t> paddingDimensions =
+        extractFromI64ArrayAttr(getPaddingDimensions());
+    SmallVector<int64_t> padToMultipleOf(paddingDimensions.size(), 1);
+    if (getPadToMultipleOf().has_value())
+      padToMultipleOf = extractFromI64ArrayAttr(*getPadToMultipleOf());
+    SmallVector<Value> replacements;
+    SmallVector<tensor::PadOp> newPadOps;
+    if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, paddingDimensions,
+                                 padToMultipleOf, paddingValues, packPaddings,
+                                 paddedOp, replacements, newPadOps))) {
+      auto diag = emitSilenceableError() << "failed to pad op";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
 
-  LinalgOp paddedOp;
-  SmallVector<int64_t> paddingDimensions =
-      extractFromI64ArrayAttr(getPaddingDimensions());
-  SmallVector<int64_t> padToMultipleOf(paddingDimensions.size(), 1);
-  if (getPadToMultipleOf().has_value())
-    padToMultipleOf = extractFromI64ArrayAttr(*getPadToMultipleOf());
-  FailureOr<SmallVector<Value>> result =
-      rewriteAsPaddedOp(rewriter, target, paddingDimensions, padToMultipleOf,
-                        paddingValues, packPaddings, paddedOp);
-  if (succeeded(result)) {
     // We need to perform our own replacement here because this API is still
     // used in patterns that "pad and hoist", for which the replacement values
     // need to be different.
     // TODO: clean this up and stop "pad and hoist" behavior more globally now
    // that we have more composable abstractions.
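Because `apply` now iterates over all payload ops itself (instead of going through `applyToOne`), the two result handles aggregate across targets: each successfully padded target contributes one op to `padded` and one `tensor.pad` per padded operand to `pad`. A sketch, assuming the match resolves to two matmuls:

```mlir
%matmuls = transform.structured.match ops{["linalg.matmul"]} in %arg1
    : (!transform.any_op) -> !transform.any_op
// %padded then maps to the two padded clones; %pad maps to all tensor.pad
// ops created for them (up to one per operand of each matmul).
%padded, %pad = transform.structured.pad %matmuls {
  padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32],
  padding_dimensions = [0, 1, 2],
  pack_paddings = [1, 1, 0]
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
```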
-    rewriter.replaceOp(target, *result);
-    results.push_back(paddedOp);
-    return DiagnosedSilenceableFailure::success();
+    rewriter.replaceOp(linalgTarget, replacements);
+    paddedOps.push_back(paddedOp);
+    padOps.append(newPadOps.begin(), newPadOps.end());
   }
 
-  return emitDefaultSilenceableFailure(target);
+  results.set(cast<OpResult>(getPadded()), paddedOps);
+  results.set(cast<OpResult>(getPad()), padOps);
+  return DiagnosedSilenceableFailure::success();
 }
 
 LogicalResult transform::PadOp::verify() {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -114,12 +114,12 @@
                                      opOperand->get(), paddingValue, nofold);
 }
 
-FailureOr<SmallVector<Value>>
-linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
-                          ArrayRef<int64_t> paddingDimensions,
-                          ArrayRef<int64_t> padToMultipleOf,
-                          ArrayRef<Attribute> paddingValues,
-                          ArrayRef<bool> packPaddings, LinalgOp &paddedOp) {
+LogicalResult linalg::rewriteAsPaddedOp(
+    RewriterBase &rewriter, LinalgOp opToPad,
+    ArrayRef<int64_t> paddingDimensions, ArrayRef<int64_t> padToMultipleOf,
+    ArrayRef<Attribute> paddingValues, ArrayRef<bool> packPaddings,
+    LinalgOp &paddedOp, SmallVector<Value> &replacements,
+    SmallVector<tensor::PadOp> &padOps) {
   LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
   Location loc = opToPad->getLoc();
 
@@ -147,6 +147,8 @@
                                        "operand cannot be bound statically");
     }
     newOperands.push_back(*paddedOperand);
+    if (auto padOp = paddedOperand->getDefiningOp<tensor::PadOp>())
+      padOps.push_back(padOp);
   }
 
   ReifiedRankedShapedTypeDims reifiedResultShapes;
@@ -184,13 +186,14 @@
   // linalg op), so that the destination buffer of the computation does not
   // change. If the padding folds away, this will materialize as a memcpy
   // between two identical buffers, which will then also fold away.
-  SmallVector<Value> copiedBack;
+  assert(paddedSubtensorResults.size() == opToPad.getNumDpsInits() &&
+         "expected matching number of results");
   for (auto it :
        llvm::zip(paddedSubtensorResults, opToPad.getDpsInitOperands())) {
-    copiedBack.push_back(rewriter.create<linalg::CopyOp>(
+    replacements.push_back(rewriter.create<linalg::CopyOp>(
         loc, std::get<0>(it), std::get<1>(it)->get()));
   }
-  return copiedBack;
+  return success();
 }
 
 FailureOr<LinalgOp>
@@ -206,10 +209,12 @@
   if (options.padToMultipleOf.has_value())
     padToMultipleOf.assign(options.padToMultipleOf->begin(),
                            options.padToMultipleOf->end());
-  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
-      rewriter, linalgOp, options.paddingDimensions, padToMultipleOf,
-      options.paddingValues, options.packPaddings, paddedOp);
-  if (failed(newResults))
+  SmallVector<Value> newResults;
+  SmallVector<tensor::PadOp> padOps;
+  if (failed(rewriteAsPaddedOp(rewriter, linalgOp, options.paddingDimensions,
+                               padToMultipleOf, options.paddingValues,
+                               options.packPaddings, paddedOp, newResults,
+                               padOps)))
     return rewriter.notifyMatchFailure(linalgOp,
                                        "failed to rewrite as a padded op");
 
@@ -249,7 +254,7 @@
   }
 
   // Replace the original operation to pad.
-  rewriter.replaceOp(linalgOp, *newResults);
+  rewriter.replaceOp(linalgOp, newResults);
 
   return paddedOp;
 }
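The test added below leans on a detail of this rewrite: with the nofold bit set (via `pack_paddings`), a `tensor.pad` is materialized even when it adds zero elements, so a later transform still has an op to target. A minimal sketch of such a zero-size pad (value names are hypothetical):

```mlir
%cst = arith.constant 0.000000e+00 : f32
// The type is unchanged, but the value now flows through a targetable
// tensor.pad op that nofold keeps from being canonicalized away.
%p = tensor.pad %t nofold low[0, 0] high[0, 0] {
^bb0(%i: index, %j: index):
  tensor.yield %cst : f32
} : tensor<4x5xf32> to tensor<4x5xf32>
```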
diff --git a/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir b/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir
@@ -0,0 +1,61 @@
+
+// RUN: mlir-opt --test-transform-dialect-interpreter -cse -canonicalize -split-input-file -verify-diagnostics %s | FileCheck %s
+
+#map = affine_map<()[s0] -> (-s0 + 12, 7)>
+
+// CHECK-LABEL: func @pad_to_memory_space(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<24x12xf32, strided<[?, ?], offset: ?>>,
+//  CHECK-SAME:     %[[arg1:.*]]: memref<12x25xf32, strided<[?, ?], offset: ?>>,
+//  CHECK-SAME:     %[[arg2:.*]]: memref<24x25xf32, strided<[?, ?], offset: ?>>,
+func.func @pad_to_memory_space(%arg0: tensor<24x12xf32>,
+                               %arg1: tensor<12x25xf32>,
+                               %arg2: tensor<24x25xf32>,
+                               %iv0 : index, %iv1 : index,
+                               %iv2 : index) -> tensor<24x25xf32> {
+  %0 = affine.min #map()[%iv2]
+
+  // CHECK: %[[s0:.*]] = memref.subview %[[arg0]]
+  %1 = tensor.extract_slice %arg0[%iv0, %iv2] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32>
+  // CHECK: %[[s1:.*]] = memref.subview %[[arg1]]
+  %2 = tensor.extract_slice %arg1[%iv2, %iv1] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor<?x5xf32>
+  // CHECK: %[[s2:.*]] = memref.subview %[[arg2]]
+  %3 = tensor.extract_slice %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32>
+
+  // CHECK: %[[alloc0:.*]] = memref.alloc() : memref<4x7xf32, 3>
+  // CHECK: linalg.fill {{.*}} outs(%[[alloc0]]
+  // CHECK: %[[alloc0_view:.*]] = memref.subview %[[alloc0]][0, 0] [4, %{{.*}}] [1, 1]
+  // CHECK: memref.copy %[[s0]], %[[alloc0_view]]
+
+  // CHECK: %[[alloc1:.*]] = memref.alloc() : memref<7x5xf32, 3>
+  // CHECK: linalg.fill {{.*}} outs(%[[alloc1]]
+  // CHECK: %[[alloc1_view:.*]] = memref.subview %[[alloc1]][0, 0] [%{{.*}}, 5] [1, 1]
+  // CHECK: memref.copy %[[s1]], %[[alloc1_view]]
+
+  // CHECK: %[[alloc2:.*]] = memref.alloc() : memref<4x5xf32, 3>
+  // CHECK: linalg.fill {{.*}} outs(%[[alloc2]]
+  // No subview because there is 0 padding
+  // CHECK: memref.copy %[[s2]], %[[alloc2]]
+
+  // CHECK: linalg.matmul ins(%[[alloc0]], %[[alloc1]] : {{.*}}) outs(%[[alloc2]] : {{.*}})
+  // Copy back result.
+  // CHECK: memref.copy %[[alloc2]], %[[s2]]
+  %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor<?x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
+
+  // insert_slice bufferizes to a no-op.
+  %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
+  func.return %5 : tensor<24x25xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %padded, %pad = transform.structured.pad %0 {
+    padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
+    padding_dimensions=[0, 1, 2],
+    pack_paddings=[1, 1, 1]
+  } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  %pad_result = transform.get_result %pad[0] : (!transform.any_op) -> !transform.any_value
+  %buffer, %replacement = transform.structured.bufferize_to_allocation %pad_result {memory_space = 3}
+  %2 = transform.bufferization.one_shot_bufferize %arg1 {bufferize_function_boundaries=true} : (!transform.any_op) -> !transform.any_op
+
+}
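The transform script in this test is the motivating use of the new `%pad` handle: `transform.get_result` converts the op handle into a value handle over each pad's result, and `bufferize_to_allocation` then materializes those values as allocations in memory space 3 (the allocs, fills, and copies in the CHECK lines). The relevant chain in isolation:

```mlir
// Value handle to result #0 of every tensor.pad mapped to %pad.
%pad_result = transform.get_result %pad[0]
    : (!transform.any_op) -> !transform.any_value
// Replace each pad result with an allocation in memory space 3.
%buffer, %replacement = transform.structured.bufferize_to_allocation
    %pad_result {memory_space = 3}
```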
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -34,11 +34,11 @@
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  %1 = transform.structured.pad %0 {
+  %padded, %pad = transform.structured.pad %0 {
     padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
     padding_dimensions=[0, 1, 2],
     pack_paddings=[1, 1, 0]
-  } : (!transform.any_op) -> !transform.any_op
+  } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 }
 
 // -----
@@ -66,12 +66,12 @@
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  %1 = transform.structured.pad %0 {
+  %padded, %pad = transform.structured.pad %0 {
     padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
     padding_dimensions=[0, 1, 2],
     pad_to_multiple_of=[2, 2, 1],
     pack_paddings=[1, 1, 0]
-  } : (!transform.any_op) -> !transform.any_op
+  } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 }
 
 // -----
@@ -109,11 +109,11 @@
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  %1 = transform.structured.pad %0 {
+  %padded, %pad = transform.structured.pad %0 {
     padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
     padding_dimensions=[0, 1, 2],
     pack_paddings=[1, 1, 0]
-  } : (!transform.any_op) -> !transform.any_op
+  } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 }
 
 // -----
@@ -130,11 +130,11 @@
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   // expected-error @below {{op expects a padding value of type 'f32', got 0 : i32}}
-  %1 = transform.structured.pad %0 {
+  %padded, %pad = transform.structured.pad %0 {
     padding_values=[0: i32, 0.0 : f32, 0.0 : f32],
     padding_dimensions=[0, 1, 2],
     pack_paddings=[1, 1, 0]
-  } : (!transform.any_op) -> !transform.any_op
+  } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 }
 
 // -----
@@ -151,11 +151,11 @@
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   // expected-error @below {{expects a padding that parses to 'f32', got "{foo}"}}
-  %1 = transform.structured.pad %0 {
+  %padded, %pad = transform.structured.pad %0 {
     padding_values=["{foo}", 0.0 : f32, 0.0 : f32],
     padding_dimensions=[0, 1, 2],
     pack_paddings=[1, 1, 0]
-  } : (!transform.any_op) -> !transform.any_op
+  } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 }
 
 // -----
@@ -175,11 +175,11 @@
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   // This error is silenceable and is not reported by this transform
   //   {{transform.structured.pad failed to apply}}
-  %1 = transform.structured.pad %0 {
+  %padded, %pad = transform.structured.pad %0 {
     padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
     padding_dimensions=[0, 1, 2],
     pack_paddings=[1, 1, 0]
-  } : (!transform.any_op) -> !transform.any_op
+  } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 }
 
 // -----
@@ -235,11 +235,11 @@
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  %1 = transform.structured.pad %0 {
+  %padded, %pad = transform.structured.pad %0 {
     padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
     padding_dimensions=[0, 1, 2],
     pack_paddings=[1, 1, 1]
-  } : (!transform.any_op) -> !transform.any_op
+  } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 }
 
 // -----
@@ -286,9 +286,9 @@
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  %1 = transform.structured.pad %0 {
+  %padded, %pad = transform.structured.pad %0 {
     padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
     padding_dimensions=[0, 1, 2],
     pack_paddings=[1, 1, 1]
-  } : (!transform.any_op) -> !transform.any_op
+  } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 }