diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -108,9 +108,12 @@
     %0 = bufferization.to_tensor %alloc restrict writable : memref<10xf32>
     ```
 
-    Selected ops that bufferize to an allocation are also supported:
+    Selected ops that bufferize to an allocation (or need special handling) are
+    also supported:
     - `tensor.pad` is lowered to an allocation, followed by a `linalg.fill`
       and a buffer copy (all on memrefs).
+    - `vector.mask` is bufferized together with its region. The allocation is
+      placed in front of the `vector.mask` op.
 
     An optional memory space attribute can be specified for the materialized
     buffer allocation.
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -60,9 +60,34 @@
 /// %0 = bufferization.to_tensor %alloc restrict writable
 ///
 /// In addition to rewriting the IR as shown above, this function returns the
-/// newly allocated buffer.
+/// newly allocated buffer. The `insertionPoint` parameter can be used to
+/// specify a custom insertion point for the buffer allocation.
 Value bufferizeToAllocation(RewriterBase &rewriter, tensor::PadOp padOp,
-                            Attribute memorySpace = {});
+                            Attribute memorySpace = {},
+                            Operation *insertionPoint = nullptr);
+
+/// Materialize a buffer allocation for the given vector.mask op and bufferize
+/// the op, including its region. E.g.:
+///
+/// %0 = vector.mask {
+///   vector.transfer_write %v, %t : vector<16xf32>, tensor<?xf32>
+/// } : vector<16xi1> -> tensor<?xf32>
+///
+/// is lowered to:
+///
+/// %alloc = memref.alloc
+/// memref.tensor_store %t, %alloc
+/// vector.mask {
+///   vector.transfer_write %arg0, %alloc : vector<16xf32>, memref<?xf32>
+/// } : vector<16xi1>
+/// %0 = bufferization.to_tensor %alloc restrict writable
+///
+/// In addition to rewriting the IR as shown above, this function returns the
+/// newly allocated buffer. The `insertionPoint` parameter can be used to
+/// specify a custom insertion point for the buffer allocation.
+Value bufferizeToAllocation(RewriterBase &rewriter, vector::MaskOp maskOp,
+                            Attribute memorySpace = {},
+                            Operation *insertionPoint = nullptr);
 
 /// Bufferize the given op with tensor semantics and materialize the result in
 /// a newly allocated buffer.
@@ -72,10 +97,17 @@
 /// supported. They are bufferized using their BufferizableOpInterface
 /// implementation.
 ///
-/// Selected ops that bufferize to an allocation are also supported:
+/// Selected ops that bufferize to an allocation (or need special handling) are
+/// also supported:
 /// - tensor.pad
+/// - vector.mask
+///
+/// This function returns the newly allocated buffer. The `insertionPoint`
+/// parameter can be used to specify a custom insertion point for the buffer
+/// allocation.
 Value bufferizeToAllocation(RewriterBase &rewriter, Operation *op,
-                            Attribute memorySpace = {});
+                            Attribute memorySpace = {},
+                            Operation *insertionPoint = nullptr);
 
 /// Try to eliminate tensor::EmptyOps inside `op` that are anchored on a
 /// LinalgOp. This transform looks for LinalgOps that have an unused output
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -170,15 +170,16 @@
 }
 
 Value linalg::bufferizeToAllocation(RewriterBase &rewriter, PadOp padOp,
-                                    Attribute memorySpace) {
+                                    Attribute memorySpace,
+                                    Operation *insertionPoint) {
   OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(padOp);
+  rewriter.setInsertionPoint(insertionPoint ? insertionPoint : padOp);
   Location loc = padOp.getLoc();
 
   // Create buffer allocation.
   Value alloc = createAllocationForTensor(rewriter, loc, padOp.getResult(),
                                           memorySpace);
-  rewriter.setInsertionPointAfter(alloc.getDefiningOp());
+  rewriter.setInsertionPoint(padOp);
 
   // Create linalg.fill or linalg.generic.
   Operation *fillOp = movePaddingToFillOrGenericOp(rewriter, loc, padOp, alloc);
@@ -201,6 +202,66 @@
   return alloc;
 }
 
+Value linalg::bufferizeToAllocation(RewriterBase &rewriter,
+                                    vector::MaskOp maskOp,
+                                    Attribute memorySpace,
+                                    Operation *insertionPoint) {
+  assert(llvm::range_size(maskOp.getMaskBlock()->without_terminator()) == 1 &&
+         "expected single masked op");
+  OpBuilder::InsertionGuard g(rewriter);
+  bufferization::BufferizationOptions options;
+  Operation *yieldOp = maskOp.getMaskRegion().front().getTerminator();
+  assert(isa<vector::YieldOp>(yieldOp) && "expected yield op terminator");
+
+  // Bufferize maskable op. By default, place the buffer allocation right
+  // before the mask op.
+  Value alloc = bufferizeToAllocation(
+      rewriter, maskOp.getMaskableOp(), memorySpace,
+      /*insertionPoint=*/insertionPoint ? insertionPoint : maskOp);
+
+  // Bufferize terminator.
+  rewriter.setInsertionPoint(yieldOp);
+  if (failed(cast<bufferization::BufferizableOpInterface>(yieldOp).bufferize(
+          rewriter, options)))
+    return nullptr;
+
+  // Erase dead to_tensor ops inside of the mask op. This is necessary because
+  // there may only be one op (apart from the terminator) inside the mask op.
+  // TODO: Remove dead to_tensor ops more aggressively during bufferization.
+  SmallVector<Operation *> toTensorOps;
+  maskOp.walk([&](bufferization::ToTensorOp toTensorOp) {
+    if (toTensorOp->getUses().empty())
+      toTensorOps.push_back(toTensorOp.getOperation());
+  });
+  for (Operation *op : toTensorOps)
+    rewriter.eraseOp(op);
+
+  // Bufferize mask op.
+  SmallVector<OpOperand *> resultUses;
+  for (Value result : maskOp.getResults())
+    if (isa<TensorType>(result.getType()))
+      for (OpOperand &use : result.getUses())
+        resultUses.push_back(&use);
+  rewriter.setInsertionPoint(maskOp);
+  if (failed(cast<bufferization::BufferizableOpInterface>(maskOp.getOperation())
+                 .bufferize(rewriter, options)))
+    return nullptr;
+
+  // Set "restrict" attribute, indicating that no other tensor aliases with
+  // this tensor. That is because we just allocated a new buffer for the tensor.
+  for (OpOperand *resultUse : resultUses) {
+    auto toTensorOp =
+        resultUse->get().getDefiningOp<bufferization::ToTensorOp>();
+    assert(toTensorOp && "expected to_tensor op");
+    rewriter.updateRootInPlace(toTensorOp, [&]() {
+      toTensorOp.setRestrict(true);
+      toTensorOp.setWritable(true);
+    });
+  }
+
+  return alloc;
+}
+
 /// Lower tensor.from_elements to a sequence of chained tensor.insert.
 FailureOr<Operation *> mlir::linalg::rewriteInDestinationPassingStyle(
     RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) {
@@ -329,12 +390,15 @@
 }
 
 Value linalg::bufferizeToAllocation(RewriterBase &rewriter, Operation *op,
-                                    Attribute memorySpace) {
+                                    Attribute memorySpace,
+                                    Operation *insertionPoint) {
   using namespace bufferization;
 
   // Call specialized overload for certain ops.
   if (auto padOp = dyn_cast<tensor::PadOp>(op))
     return bufferizeToAllocation(rewriter, padOp, memorySpace);
+  if (auto maskOp = dyn_cast<vector::MaskOp>(op))
+    return bufferizeToAllocation(rewriter, maskOp, memorySpace);
 
   // Only bufferizable ops are supported.
   auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
@@ -386,7 +450,7 @@
 
   // Allocate buffers.
   OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(op);
+  rewriter.setInsertionPoint(insertionPoint ? insertionPoint : op);
   SmallVector<Value> allocs;
   for (OpOperand *operand : outOfPlaceOperands) {
     Value alloc = createAllocationForTensor(rewriter, op->getLoc(),
@@ -401,6 +465,7 @@
   }
 
   // Bufferize the op.
+  rewriter.setInsertionPoint(op);
   if (failed(bufferizableOp.bufferize(rewriter, options)))
     return nullptr;
 
diff --git a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
@@ -100,3 +100,24 @@
   // expected-error @below{{failed to bufferize operation}}
   %2 = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op
 }
+
+// -----
+
+// CHECK-LABEL: func @vector_mask(
+// CHECK-SAME:      %[[t:.*]]: tensor<?xf32>,
+// CHECK:         %[[alloc:.*]] = memref.alloc(%{{.*}}) : memref<?xf32, 4>
+// CHECK:         memref.tensor_store %[[t]], %[[alloc]]
+// CHECK:         vector.mask %{{.*}} { vector.transfer_write %{{.*}}, %[[alloc]]
+// CHECK:         %[[r:.*]] = bufferization.to_tensor %[[alloc]] restrict writable
+// CHECK:         memref.dealloc %[[alloc]]
+// CHECK:         return %[[r]]
+func.func @vector_mask(%t: tensor<?xf32>, %val: vector<16xf32>, %idx: index, %m0: vector<16xi1>) -> tensor<?xf32> {
+  %r = vector.mask %m0 { vector.transfer_write %val, %t[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> tensor<?xf32>
+  return %r : tensor<?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["vector.mask"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %2 = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op
+}
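
For readers integrating this change, the sketch below shows one way the new `vector::MaskOp` overload of `linalg::bufferizeToAllocation` and its `insertionPoint` parameter might be driven from C++. The helper name `bufferizeMasksToAllocations` and the walk-based driver are hypothetical illustrations, not part of this patch; only the `bufferizeToAllocation` signature and the ops involved come from the changes above. It assumes the masked tensors (and thus their dynamic sizes) are function arguments, so an allocation hoisted to the function entry still dominates its uses.

```cpp
// Hypothetical driver (not part of the patch): bufferize every vector.mask op
// in a function to a newly allocated buffer, placing each allocation at the
// top of the function body via the new `insertionPoint` parameter.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

static LogicalResult bufferizeMasksToAllocations(func::FuncOp funcOp,
                                                 Attribute memorySpace = {}) {
  IRRewriter rewriter(funcOp.getContext());

  // Collect the ops first; bufferizeToAllocation rewrites the IR in place.
  SmallVector<vector::MaskOp> maskOps;
  funcOp.walk([&](vector::MaskOp maskOp) { maskOps.push_back(maskOp); });

  for (vector::MaskOp maskOp : maskOps) {
    // Place the allocation before the first op of the function instead of the
    // default position directly in front of the vector.mask op. This assumes
    // the allocation's size operands are available at the function entry.
    Operation *insertionPoint = &funcOp.getBody().front().front();
    Value alloc = linalg::bufferizeToAllocation(rewriter, maskOp, memorySpace,
                                                insertionPoint);
    if (!alloc)
      return failure();
  }
  return success();
}
```

By default the allocation is placed directly in front of the `vector.mask` op; passing a custom `insertionPoint`, as above, lets callers place it elsewhere, for example ahead of surrounding control flow.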