diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -503,115 +503,6 @@ }]; } -//===----------------------------------------------------------------------===// -// ParallelInsertSliceOp -//===----------------------------------------------------------------------===// - -// TODO: Implement PerformConcurrentlyOpInterface. -def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [ - AttrSizedOperandSegments, - OffsetSizeAndStrideOpInterface, - // TODO: Cannot use an interface here atm, verify this manually for now. - // HasParent<"ParallelCombiningOpInterface"> - ]> { - let summary = [{ - Specify the tensor slice update of a single thread within the terminator of - an `scf.foreach_thread`. - }]; - let description = [{ - The parent `scf.foreach_thread` returns values that are formed by aggregating - the actions of all the ops contained within the `perform_concurrently` - terminator of all the threads, in some unspecified order. - The `scf.foreach_thread.parallel_insert_slice` is one such op allowed in - the `scf.foreach_thread.perform_concurrently` terminator. - - Conflicting writes result in undefined semantics, in that the indices written - to by multiple parallel updates might contain data from any of the updates, or - even a malformed bit pattern. - - If an index is updated exactly once, the value contained at that index - in the resulting tensor will be equal to the value at a corresponding index of a - slice that was used for the updated. If an index is not updated at all, its value - will be equal to the one in the original tensor. - - This op does not create a new value, which allows maintaining a clean - separation between the subset and full tensor. 
- Note that we cannot mark this operation as pure (NoSideEffects), even - though it has no side effects, because it will get DCEd during - canonicalization. - }]; - - let arguments = (ins - AnyRankedTensor:$source, - AnyRankedTensor:$dest, - Variadic:$offsets, - Variadic:$sizes, - Variadic:$strides, - I64ArrayAttr:$static_offsets, - I64ArrayAttr:$static_sizes, - I64ArrayAttr:$static_strides - ); - let assemblyFormat = [{ - $source `into` $dest `` - custom($offsets, $static_offsets) - custom($sizes, $static_sizes) - custom($strides, $static_strides) - attr-dict `:` type($source) `into` type($dest) - }]; - - let extraClassDeclaration = [{ - ::mlir::Operation::operand_range offsets() { return getOffsets(); } - ::mlir::Operation::operand_range sizes() { return getSizes(); } - ::mlir::Operation::operand_range strides() { return getStrides(); } - ::mlir::ArrayAttr static_offsets() { return getStaticOffsets(); } - ::mlir::ArrayAttr static_sizes() { return getStaticSizes(); } - ::mlir::ArrayAttr static_strides() { return getStaticStrides(); } - - Type yieldedType() { return getDest().getType(); } - - RankedTensorType getSourceType() { - return getSource().getType().cast(); - } - - ParallelCombiningOpInterface getParallelCombiningParent() { - return dyn_cast( - getOperation()->getParentOp()); - } - - /// Return the expected rank of each of the `static_offsets`, `static_sizes` - /// and `static_strides` attributes. - std::array getArrayAttrMaxRanks() { - unsigned rank = getSourceType().getRank(); - return {rank, rank, rank}; - } - - /// Return the number of leading operands before `offsets`, `sizes` and - /// `strides` operands. - static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; } - - /// Return the OpResult of the enclosing ForeachThreadOp that is - /// corresponding to this ParallelInsertSliceOp. - OpResult getTiedOpResult(); - }]; - - let builders = [ - // Build a ParallelInsertSliceOp with mixed static and dynamic entries. 
- OpBuilder<(ins "Value":$source, "Value":$dest, - "ArrayRef":$offsets, "ArrayRef":$sizes, - "ArrayRef":$strides, - CArg<"ArrayRef", "{}">:$attrs)>, - - // Build a ParallelInsertSliceOp with dynamic entries. - OpBuilder<(ins "Value":$source, "Value":$dest, - "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides, - CArg<"ArrayRef", "{}">:$attrs)> - ]; - - let hasCanonicalizer = 1; - let hasFolder = 1; - let hasVerifier = 1; -} - //===----------------------------------------------------------------------===// // IfOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h --- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h +++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h @@ -17,6 +17,7 @@ #include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/ParallelCombiningOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/TilingInterface.h" #include "mlir/Interfaces/ViewLikeInterface.h" diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -13,6 +13,7 @@ include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/ParallelCombiningOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/TilingInterface.td" include "mlir/Interfaces/ViewLikeInterface.td" @@ -1051,6 +1052,110 @@ let hasRegionVerifier = 1; } +//===----------------------------------------------------------------------===// +// ParallelInsertSliceOp 
+//===----------------------------------------------------------------------===//
+
+// TODO: Implement PerformConcurrentlyOpInterface.
+def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
+    AttrSizedOperandSegments,
+    OffsetSizeAndStrideOpInterface,
+    // TODO: Cannot use an interface here atm, verify this manually for now.
+    // HasParent<"ParallelCombiningOpInterface">
+  ]> {
+  let summary = [{
+    Specify the tensor slice update of a single thread of a parent
+    ParallelCombiningOpInterface op.
+  }];
+  let description = [{
+    The `parallel_insert_slice` yields a subset tensor value to its parent
+    ParallelCombiningOpInterface. These subset tensor values are aggregated,
+    in some unspecified order, into a full tensor value returned by the parent
+    parallel iterating op.
+    The `parallel_insert_slice` is one such op allowed in the
+    ParallelCombiningOpInterface op.
+
+    Conflicting writes result in undefined semantics, in that the indices written
+    to by multiple parallel updates might contain data from any of the updates,
+    or even a malformed bit pattern.
+
+    If an index is updated exactly once, the value contained at that index
+    in the resulting tensor will be equal to the value at a corresponding index
+    of a slice that was used for the update. If an index is not updated at all,
+    its value will be equal to the one in the original tensor.
+
+    This op does not create a new value, which allows maintaining a clean
+    separation between the subset and full tensor.
+
+    Note that we cannot mark this operation as pure (NoSideEffects), even
+    though it has no side effects, because it will get DCEd during
+    canonicalization.
+ }]; + + let arguments = (ins + AnyRankedTensor:$source, + AnyRankedTensor:$dest, + Variadic:$offsets, + Variadic:$sizes, + Variadic:$strides, + I64ArrayAttr:$static_offsets, + I64ArrayAttr:$static_sizes, + I64ArrayAttr:$static_strides + ); + let assemblyFormat = [{ + $source `into` $dest `` + custom($offsets, $static_offsets) + custom($sizes, $static_sizes) + custom($strides, $static_strides) + attr-dict `:` type($source) `into` type($dest) + }]; + + let extraClassDeclaration = [{ + Type yieldedType() { return getDest().getType(); } + + RankedTensorType getSourceType() { + return getSource().getType().cast(); + } + + ParallelCombiningOpInterface getParallelCombiningParent() { + return dyn_cast( + getOperation()->getParentOp()); + } + + /// Return the expected rank of each of the `static_offsets`, `static_sizes` + /// and `static_strides` attributes. + std::array getArrayAttrMaxRanks() { + unsigned rank = getSourceType().getRank(); + return {rank, rank, rank}; + } + + /// Return the number of leading operands before `offsets`, `sizes` and + /// `strides` operands. + static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; } + + /// Return the OpResult of the enclosing ForeachThreadOp that is + /// corresponding to this ParallelInsertSliceOp. + OpResult getTiedOpResult(); + }]; + + let builders = [ + // Build a ParallelInsertSliceOp with mixed static and dynamic entries. + OpBuilder<(ins "Value":$source, "Value":$dest, + "ArrayRef":$offsets, "ArrayRef":$sizes, + "ArrayRef":$strides, + CArg<"ArrayRef", "{}">:$attrs)>, + + // Build a ParallelInsertSliceOp with dynamic entries. 
+ OpBuilder<(ins "Value":$source, "Value":$dest, + "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides, + CArg<"ArrayRef", "{}">:$attrs)> + ]; + + let hasCanonicalizer = 1; + let hasFolder = 1; + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // SplatOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt --- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt @@ -13,7 +13,6 @@ MLIRControlFlowDialect MLIRIR MLIRLoopLikeInterface - MLIRParallelCombiningOpInterface MLIRSideEffectInterfaces ) diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1211,137 +1211,6 @@ return dyn_cast(containingOp); } -//===----------------------------------------------------------------------===// -// ParallelInsertSliceOp -//===----------------------------------------------------------------------===// - -OpResult ParallelInsertSliceOp::getTiedOpResult() { - ParallelCombiningOpInterface parallelCombiningParent = - getParallelCombiningParent(); - for (const auto &it : - llvm::enumerate(parallelCombiningParent.getYieldingOps())) { - Operation &nextOp = it.value(); - if (&nextOp == getOperation()) - return parallelCombiningParent.getParentResult(it.index()); - } - llvm_unreachable("ParallelInsertSliceOp no tied OpResult found"); -} - -// Build a ParallelInsertSliceOp with mixed static and dynamic entries. 
-void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result, - Value source, Value dest, - ArrayRef offsets, - ArrayRef sizes, - ArrayRef strides, - ArrayRef attrs) { - SmallVector staticOffsets, staticSizes, staticStrides; - SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; - dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, - ShapedType::kDynamicStrideOrOffset); - dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, - ShapedType::kDynamicSize); - dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, - ShapedType::kDynamicStrideOrOffset); - build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes, - dynamicStrides, b.getI64ArrayAttr(staticOffsets), - b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); - result.addAttributes(attrs); -} - -// Build a ParallelInsertSliceOp with dynamic entries. -void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result, - Value source, Value dest, ValueRange offsets, - ValueRange sizes, ValueRange strides, - ArrayRef attrs) { - SmallVector offsetValues = llvm::to_vector<4>( - llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); - SmallVector sizeValues = llvm::to_vector<4>( - llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); - SmallVector strideValues = llvm::to_vector<4>( - llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); - build(b, result, source, dest, offsetValues, sizeValues, strideValues); -} - -LogicalResult ParallelInsertSliceOp::verify() { - if (!isa(getOperation()->getParentOp())) - return this->emitError("expected ParallelCombiningOpInterface parent, got:") - << *(getOperation()->getParentOp()); - return success(); -} - -namespace { -/// Pattern to rewrite a parallel_insert_slice op with constant arguments. 
-class ParallelInsertSliceOpConstantArgumentFolder final - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp, - PatternRewriter &rewriter) const override { - // No constant operand, just return. - if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) { - return matchPattern(operand, matchConstantIndex()); - })) - return failure(); - - // At least one of offsets/sizes/strides is a new constant. - // Form the new list of operands and constant attributes from the - // existing. - SmallVector mixedOffsets(insertSliceOp.getMixedOffsets()); - SmallVector mixedSizes(insertSliceOp.getMixedSizes()); - SmallVector mixedStrides(insertSliceOp.getMixedStrides()); - canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset); - canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic); - canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset); - - // Create the new op in canonical form. - rewriter.replaceOpWithNewOp( - insertSliceOp, insertSliceOp.getSource(), insertSliceOp.getDest(), - mixedOffsets, mixedSizes, mixedStrides); - return success(); - } -}; -} // namespace - -/// Fold a parallel_insert_slice source coming from a tensor.cast op. -/// -/// Example: -/// ``` -/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) { -/// %1 = compute_some_tensor() : tensor<64xf32> -/// %2 = tensor.cast %1 : tensor<64xf32> to tensor -/// scf.foreach_thread.perform_concurrently { -/// scf.foreach_thread.parallel_insert_slice %2 into %out[...] [64] [1] : -/// tensor into tensor<128xf32> -/// } -/// } -/// ``` -/// -/// is folded into: -/// ``` -/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) { -/// %1 = compute_some_tensor() : tensor<64xf32> -/// scf.foreach_thread.perform_concurrently { -/// scf.foreach_thread.parallel_insert_slice %1 into %out[...] 
[64] [1] : -/// tensor<64xf32> into tensor<128xf32> -/// } -/// } -/// ``` -LogicalResult -ParallelInsertSliceOp::fold(ArrayRef operands, - SmallVectorImpl &results) { - auto sourceCast = getSource().getDefiningOp(); - if (!sourceCast) - return failure(); - getSourceMutable().assign(sourceCast.getSource()); - return success(); -} - -void ParallelInsertSliceOp::getCanonicalizationPatterns( - RewritePatternSet &results, MLIRContext *context) { - results.add(context); -} - //===----------------------------------------------------------------------===// // PerformConcurrentlyOp //===----------------------------------------------------------------------===// @@ -1355,10 +1224,12 @@ LogicalResult PerformConcurrentlyOp::verify() { // TODO: PerformConcurrentlyOpInterface. - for (const Operation &op : getRegion().front().getOperations()) - if (!isa(op)) - return emitOpError( - "expected only scf.foreach_thread.parallel_insert_slice ops"); + for (const Operation &op : getRegion().front().getOperations()) { + if (!isa(op)) { + return this->emitOpError("expected only ") + << tensor::ParallelInsertSliceOp::getOperationName() << " ops"; + } + } return success(); } @@ -1396,7 +1267,7 @@ SmallVector PerformConcurrentlyOp::getYieldedTypes() { return llvm::to_vector<4>( llvm::map_range(getYieldingOps(), [](Operation &op) { - auto insertSliceOp = dyn_cast(&op); + auto insertSliceOp = dyn_cast(&op); return insertSliceOp ? 
insertSliceOp.yieldedType() : Type(); })); } diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -927,7 +927,7 @@ getInsertionDest(ForeachThreadOp foreachThreadOp) { PerformConcurrentlyOp terminator = foreachThreadOp.getTerminator(); SmallVector result; - terminator.walk([&](ParallelInsertSliceOp insertOp) { + terminator.walk([&](tensor::ParallelInsertSliceOp insertOp) { result.push_back(&insertOp->getOpOperand(1) /*dest*/); }); return result; @@ -1004,248 +1004,6 @@ } }; -/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair match (i.e. -/// equivalent operand / result and same offset/sizes/strides specification). -static bool areEquivalentExtractSliceOps(const AnalysisState &state, - ExtractSliceOp st, - ParallelInsertSliceOp sti) { - if (!st || !sti) - return false; - if (st != sti && - !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest())) - return false; - if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue)) - return false; - return true; -} - -/// Return true if `value` is originating from an ExtractSliceOp that matches -/// the given InsertSliceOp. -static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value, - ParallelInsertSliceOp insertOp) { - auto condition = [&](Value val) { - if (auto extractOp = val.getDefiningOp()) - if (areEquivalentExtractSliceOps(state, extractOp, insertOp)) - return true; - return false; - }; - - return llvm::all_of(state.findValueInReverseUseDefChain(value, condition), - condition); -} - -/// Analysis of ParallelInsertSliceOp. 
-struct ParallelInsertSliceOpInterface - : public BufferizableOpInterface::ExternalModel< - ParallelInsertSliceOpInterface, ParallelInsertSliceOp> { - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - if (&opOperand != &op->getOpOperand(1) /*dest*/) - return {}; - - // ParallelInsertSliceOp itself has no results, query its tied op results. - auto insertOp = cast(op); - return {insertOp.getTiedOpResult()}; - } - - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - return true; - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - return &opOperand == &op->getOpOperand(1) /*dest*/; - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Equivalent; - } - - LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter, - const AnalysisState &state) const { - // This interface method is overridden because we want to set a custom - // insertion point for tensor copies. They should be inserted right before - // the ForeachThreadOp. E.g.: - // - // %r0, %r1 = foreach_thead ... { - // ... - // perform_concurrently { - // parallel_insert_slice %a into %b ... {inplace = ["true", "true"]} - // parallel_insert_slice %c into %d ... {inplace = ["true", "false"]} - // } - // } - // - // After TensorCopyInsertion: - // - // %copy = bufferization.alloc_tensor() copy(%d) - // %r0, %r1 = foreach_thead ... { - // ... - // perform_concurrently { - // parallel_insert_slice %a into %b ... - // parallel_insert_slice %c into %copy ... 
- // } - // } - - OpBuilder::InsertionGuard g(rewriter); - auto parallelInsertSliceOp = cast(op); - ParallelCombiningOpInterface parallelCombiningParent = - parallelInsertSliceOp.getParallelCombiningParent(); - Operation *parallelIteratingOp = parallelCombiningParent->getParentOp(); - - // Nothing to do if the destination tensor is inplace. - assert(state.isInPlace(op->getOpOperand(0) /*src*/) && - "source is always in-place"); - if (state.isInPlace(op->getOpOperand(1) /*dest*/)) - return success(); - - // Find corresponding OpResult. - OpResult opResult = parallelInsertSliceOp.getTiedOpResult(); - - // Insert tensor allocation right before the ForeachThreadOp. - rewriter.setInsertionPoint(parallelIteratingOp); - bool isYielded = state.isTensorYielded(opResult); - FailureOr alloc = allocateTensorForShapedValue( - rewriter, op->getLoc(), parallelInsertSliceOp.getDest(), - /*escape=*/isYielded, state.getOptions()); - if (failed(alloc)) - return failure(); - - // Update destination operand. - rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() { - parallelInsertSliceOp.getDestMutable().assign(*alloc); - }); - - return success(); - } - - LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { - OpBuilder::InsertionGuard g(rewriter); - auto parallelInsertSliceOp = cast(op); - ParallelCombiningOpInterface parallelCombiningParent = - parallelInsertSliceOp.getParallelCombiningParent(); - Operation *parallelIteratingOp = parallelCombiningParent->getParentOp(); - - // Get destination buffer. - FailureOr destBuffer = - getBuffer(rewriter, parallelInsertSliceOp.getDest(), options); - if (failed(destBuffer)) - return failure(); - - // Bufferize the ParallelInsertSliceOp outside of `parallelCombiningParent`. 
- rewriter.setInsertionPoint(parallelCombiningParent); - FailureOr srcBuffer = - getBuffer(rewriter, parallelInsertSliceOp.getSource(), options); - if (failed(srcBuffer)) - return failure(); - Value subview = rewriter.create( - parallelInsertSliceOp.getLoc(), *destBuffer, - parallelInsertSliceOp.getMixedOffsets(), - parallelInsertSliceOp.getMixedSizes(), - parallelInsertSliceOp.getMixedStrides()); - // This memcpy will fold away if everything bufferizes in-place. - if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(), - *srcBuffer, subview))) - return failure(); - - // Replace all uses of parallelIteratingOp (just the corresponding result). - rewriter.setInsertionPointAfter(parallelIteratingOp); - Value toTensorOp = - rewriter.create(parallelIteratingOp->getLoc(), *destBuffer); - // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps. - SmallVector resultUses = llvm::to_vector( - llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(), - [](OpOperand &use) { return &use; })); - for (OpOperand *use : resultUses) { - rewriter.updateRootInPlace(use->getOwner(), - [&]() { use->set(toTensorOp); }); - } - rewriter.eraseOp(op); - return success(); - } - - // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share - // the code. - bool isNotConflicting(Operation *op, OpOperand *uRead, - OpOperand *uConflictingWrite, - const AnalysisState &state) const { - Operation *readingOp = uRead->getOwner(); - Operation *conflictingWritingOp = uConflictingWrite->getOwner(); - - // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If - // uRead is an InsertSliceOp... - if (auto insertSliceOp = dyn_cast(readingOp)) { - // As an example, consider the following IR. 
- // - // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] } - // %1 = linalg.fill %cst, %0 {inplace= [true] } - // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] - // {inplace= [true] } - - // TODO: Use insertSliceOp.getDestOpOperand etc. when available. - if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ && - hasMatchingExtractSliceOp(state, uConflictingWrite->get(), - insertSliceOp)) - // Case 1: The main insight is that InsertSliceOp reads only part of - // the destination tensor. The overwritten area is not read. If - // uConflictingWrite writes into exactly the memory location that is - // being read by uRead, this is not a conflict. - // - // In the above example: - // uRead = OpOperand 1 (%t) of tensor.insert_slice - // uConflictingWrite = OpOperand 1 (%0) of linalg.fill - // - // The read of %t does not conflict with the write of the FillOp - // (same aliases!) because the area that the FillOp operates on is - // exactly the one that is *not* read via %t. - return true; - - if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ && - uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && - hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp)) - // Case 2: The read of the source tensor and the write to the dest - // tensor via an InsertSliceOp is not a conflict if the read is - // reading exactly that part of an equivalent tensor that the - // InsertSliceOp is writing. - // - // In the above example: - // uRead = OpOperand 0 (%1) of tensor.insert_slice - // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice - return true; - } - - // If uConflictingWrite is an InsertSliceOp... - if (auto insertSliceOp = - dyn_cast(conflictingWritingOp)) - // As an example, consider the following IR. 
- // - // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] } - // %1 = linalg.fill %cst, %0 {inplace= [true] } - // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] - // {inplace= [true] } - // %3 = vector.transfer_read %1, %cst - // - // In the above example: - // uRead = OpOperand 0 (%1) of vector.transfer_read - // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice - // lastWrite = %1 - // - // This is not a conflict because the InsertSliceOp overwrites the - // memory segment of %1 with the exact same data. (Effectively, there - // is no memory write here.) - if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && - state.areEquivalentBufferizedValues(uRead->get(), - insertSliceOp.getSource()) && - hasMatchingExtractSliceOp(state, insertSliceOp.getSource(), - insertSliceOp)) - return true; - - return false; - } -}; - } // namespace } // namespace scf } // namespace mlir @@ -1257,8 +1015,6 @@ ForOp::attachInterface(*ctx); IfOp::attachInterface(*ctx); ForeachThreadOp::attachInterface(*ctx); - ParallelInsertSliceOp::attachInterface( - *ctx); PerformConcurrentlyOp::attachInterface( *ctx); WhileOp::attachInterface(*ctx); diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt @@ -26,6 +26,7 @@ MLIRDialectUtils MLIRIR MLIRInferTypeOpInterface + MLIRParallelCombiningOpInterface MLIRSideEffectInterfaces MLIRSupport MLIRViewLikeInterface diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -2179,6 +2179,137 @@ return {}; } +//===----------------------------------------------------------------------===// +// ParallelInsertSliceOp +//===----------------------------------------------------------------------===// + +OpResult 
ParallelInsertSliceOp::getTiedOpResult() { + ParallelCombiningOpInterface parallelCombiningParent = + getParallelCombiningParent(); + for (const auto &it : + llvm::enumerate(parallelCombiningParent.getYieldingOps())) { + Operation &nextOp = it.value(); + if (&nextOp == getOperation()) + return parallelCombiningParent.getParentResult(it.index()); + } + llvm_unreachable("ParallelInsertSliceOp no tied OpResult found"); +} + +// Build a ParallelInsertSliceOp with mixed static and dynamic entries. +void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result, + Value source, Value dest, + ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides, + ArrayRef attrs) { + SmallVector staticOffsets, staticSizes, staticStrides; + SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, + ShapedType::kDynamicStrideOrOffset); + dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, + ShapedType::kDynamicSize); + dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, + ShapedType::kDynamicStrideOrOffset); + build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes, + dynamicStrides, b.getI64ArrayAttr(staticOffsets), + b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); + result.addAttributes(attrs); +} + +// Build a ParallelInsertSliceOp with dynamic entries. 
+void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result, + Value source, Value dest, ValueRange offsets, + ValueRange sizes, ValueRange strides, + ArrayRef attrs) { + SmallVector offsetValues = llvm::to_vector<4>( + llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); + SmallVector sizeValues = llvm::to_vector<4>( + llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); + SmallVector strideValues = llvm::to_vector<4>( + llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); + build(b, result, source, dest, offsetValues, sizeValues, strideValues); +} + +LogicalResult ParallelInsertSliceOp::verify() { + if (!isa(getOperation()->getParentOp())) + return this->emitError("expected ParallelCombiningOpInterface parent, got:") + << *(getOperation()->getParentOp()); + return success(); +} + +namespace { +/// Pattern to rewrite a parallel_insert_slice op with constant arguments. +class ParallelInsertSliceOpConstantArgumentFolder final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp, + PatternRewriter &rewriter) const override { + // No constant operand, just return. + if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) { + return matchPattern(operand, matchConstantIndex()); + })) + return failure(); + + // At least one of offsets/sizes/strides is a new constant. + // Form the new list of operands and constant attributes from the + // existing. + SmallVector mixedOffsets(insertSliceOp.getMixedOffsets()); + SmallVector mixedSizes(insertSliceOp.getMixedSizes()); + SmallVector mixedStrides(insertSliceOp.getMixedStrides()); + canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset); + canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic); + canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset); + + // Create the new op in canonical form. 
+    rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
+        insertSliceOp, insertSliceOp.getSource(), insertSliceOp.getDest(),
+        mixedOffsets, mixedSizes, mixedStrides);
+    return success();
+  }
+};
+} // namespace
+
+/// Fold a parallel_insert_slice source coming from a tensor.cast op.
+///
+/// Example:
+/// ```
+/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
+///   %1 = compute_some_tensor() : tensor<64xf32>
+///   %2 = tensor.cast %1 : tensor<64xf32> to tensor<?xf32>
+///   scf.foreach_thread.perform_concurrently {
+///     tensor.parallel_insert_slice %2 into %out[...] [64] [1] :
+///       tensor<?xf32> into tensor<128xf32>
+///   }
+/// }
+/// ```
+///
+/// is folded into:
+/// ```
+/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
+///   %1 = compute_some_tensor() : tensor<64xf32>
+///   scf.foreach_thread.perform_concurrently {
+///     tensor.parallel_insert_slice %1 into %out[...] [64] [1] :
+///       tensor<64xf32> into tensor<128xf32>
+///   }
+/// }
+/// ```
+LogicalResult
+ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
+                            SmallVectorImpl<OpFoldResult> &results) {
+  auto sourceCast = getSource().getDefiningOp<tensor::CastOp>();
+  if (!sourceCast)
+    return failure();
+  getSourceMutable().assign(sourceCast.getSource());
+  return success();
+}
+
+void ParallelInsertSliceOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // SplatOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -810,6 +810,248 @@
   }
 };
 
+/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair match (i.e.
+/// equivalent operand / result and same offset/sizes/strides specification). +static bool areEquivalentExtractSliceOps(const AnalysisState &state, + ExtractSliceOp st, + ParallelInsertSliceOp sti) { + if (!st || !sti) + return false; + if (st != sti && + !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest())) + return false; + if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue)) + return false; + return true; +} + +/// Return true if `value` is originating from an ExtractSliceOp that matches +/// the given InsertSliceOp. +static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value, + ParallelInsertSliceOp insertOp) { + auto condition = [&](Value val) { + if (auto extractOp = val.getDefiningOp()) + if (areEquivalentExtractSliceOps(state, extractOp, insertOp)) + return true; + return false; + }; + + return llvm::all_of(state.findValueInReverseUseDefChain(value, condition), + condition); +} + +/// Analysis of ParallelInsertSliceOp. +struct ParallelInsertSliceOpInterface + : public BufferizableOpInterface::ExternalModel< + ParallelInsertSliceOpInterface, ParallelInsertSliceOp> { + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + if (&opOperand != &op->getOpOperand(1) /*dest*/) + return {}; + + // ParallelInsertSliceOp itself has no results, query its tied op results. 
+ auto insertOp = cast(op); + return {insertOp.getTiedOpResult()}; + } + + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return &opOperand == &op->getOpOperand(1) /*dest*/; + } + + BufferRelation bufferRelation(Operation *op, OpResult opResult, + const AnalysisState &state) const { + return BufferRelation::Equivalent; + } + + LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter, + const AnalysisState &state) const { + // This interface method is overridden because we want to set a custom + // insertion point for tensor copies. They should be inserted right before + // the ForeachThreadOp. E.g.: + // + // %r0, %r1 = foreach_thead ... { + // ... + // perform_concurrently { + // parallel_insert_slice %a into %b ... {inplace = ["true", "true"]} + // parallel_insert_slice %c into %d ... {inplace = ["true", "false"]} + // } + // } + // + // After TensorCopyInsertion: + // + // %copy = bufferization.alloc_tensor() copy(%d) + // %r0, %r1 = foreach_thead ... { + // ... + // perform_concurrently { + // parallel_insert_slice %a into %b ... + // parallel_insert_slice %c into %copy ... + // } + // } + + OpBuilder::InsertionGuard g(rewriter); + auto parallelInsertSliceOp = cast(op); + ParallelCombiningOpInterface parallelCombiningParent = + parallelInsertSliceOp.getParallelCombiningParent(); + Operation *parallelIteratingOp = parallelCombiningParent->getParentOp(); + + // Nothing to do if the destination tensor is inplace. + assert(state.isInPlace(op->getOpOperand(0) /*src*/) && + "source is always in-place"); + if (state.isInPlace(op->getOpOperand(1) /*dest*/)) + return success(); + + // Find corresponding OpResult. + OpResult opResult = parallelInsertSliceOp.getTiedOpResult(); + + // Insert tensor allocation right before the ForeachThreadOp. 
+ rewriter.setInsertionPoint(parallelIteratingOp); + bool isYielded = state.isTensorYielded(opResult); + FailureOr alloc = allocateTensorForShapedValue( + rewriter, op->getLoc(), parallelInsertSliceOp.getDest(), + /*escape=*/isYielded, state.getOptions()); + if (failed(alloc)) + return failure(); + + // Update destination operand. + rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() { + parallelInsertSliceOp.getDestMutable().assign(*alloc); + }); + + return success(); + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationOptions &options) const { + OpBuilder::InsertionGuard g(rewriter); + auto parallelInsertSliceOp = cast(op); + ParallelCombiningOpInterface parallelCombiningParent = + parallelInsertSliceOp.getParallelCombiningParent(); + Operation *parallelIteratingOp = parallelCombiningParent->getParentOp(); + + // Get destination buffer. + FailureOr destBuffer = + getBuffer(rewriter, parallelInsertSliceOp.getDest(), options); + if (failed(destBuffer)) + return failure(); + + // Bufferize the ParallelInsertSliceOp outside of `parallelCombiningParent`. + rewriter.setInsertionPoint(parallelCombiningParent); + FailureOr srcBuffer = + getBuffer(rewriter, parallelInsertSliceOp.getSource(), options); + if (failed(srcBuffer)) + return failure(); + Value subview = rewriter.create( + parallelInsertSliceOp.getLoc(), *destBuffer, + parallelInsertSliceOp.getMixedOffsets(), + parallelInsertSliceOp.getMixedSizes(), + parallelInsertSliceOp.getMixedStrides()); + // This memcpy will fold away if everything bufferizes in-place. + if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(), + *srcBuffer, subview))) + return failure(); + + // Replace all uses of parallelIteratingOp (just the corresponding result). + rewriter.setInsertionPointAfter(parallelIteratingOp); + Value toTensorOp = + rewriter.create(parallelIteratingOp->getLoc(), *destBuffer); + // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps. 
+ SmallVector resultUses = llvm::to_vector( + llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(), + [](OpOperand &use) { return &use; })); + for (OpOperand *use : resultUses) { + rewriter.updateRootInPlace(use->getOwner(), + [&]() { use->set(toTensorOp); }); + } + rewriter.eraseOp(op); + return success(); + } + + // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share + // the code. + bool isNotConflicting(Operation *op, OpOperand *uRead, + OpOperand *uConflictingWrite, + const AnalysisState &state) const { + Operation *readingOp = uRead->getOwner(); + Operation *conflictingWritingOp = uConflictingWrite->getOwner(); + + // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If + // uRead is an InsertSliceOp... + if (auto insertSliceOp = dyn_cast(readingOp)) { + // As an example, consider the following IR. + // + // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] } + // %1 = linalg.fill %cst, %0 {inplace= [true] } + // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] + // {inplace= [true] } + + // TODO: Use insertSliceOp.getDestOpOperand etc. when available. + if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ && + hasMatchingExtractSliceOp(state, uConflictingWrite->get(), + insertSliceOp)) + // Case 1: The main insight is that InsertSliceOp reads only part of + // the destination tensor. The overwritten area is not read. If + // uConflictingWrite writes into exactly the memory location that is + // being read by uRead, this is not a conflict. + // + // In the above example: + // uRead = OpOperand 1 (%t) of tensor.insert_slice + // uConflictingWrite = OpOperand 1 (%0) of linalg.fill + // + // The read of %t does not conflict with the write of the FillOp + // (same aliases!) because the area that the FillOp operates on is + // exactly the one that is *not* read via %t. 
+ return true; + + if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ && + uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && + hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp)) + // Case 2: The read of the source tensor and the write to the dest + // tensor via an InsertSliceOp is not a conflict if the read is + // reading exactly that part of an equivalent tensor that the + // InsertSliceOp is writing. + // + // In the above example: + // uRead = OpOperand 0 (%1) of tensor.insert_slice + // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice + return true; + } + + // If uConflictingWrite is an InsertSliceOp... + if (auto insertSliceOp = + dyn_cast(conflictingWritingOp)) + // As an example, consider the following IR. + // + // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] } + // %1 = linalg.fill %cst, %0 {inplace= [true] } + // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] + // {inplace= [true] } + // %3 = vector.transfer_read %1, %cst + // + // In the above example: + // uRead = OpOperand 0 (%1) of vector.transfer_read + // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice + // lastWrite = %1 + // + // This is not a conflict because the InsertSliceOp overwrites the + // memory segment of %1 with the exact same data. (Effectively, there + // is no memory write here.) 
+ if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && + state.areEquivalentBufferizedValues(uRead->get(), + insertSliceOp.getSource()) && + hasMatchingExtractSliceOp(state, insertSliceOp.getSource(), + insertSliceOp)) + return true; + + return false; + } +}; + } // namespace } // namespace tensor } // namespace mlir @@ -827,6 +1069,8 @@ GenerateOp::attachInterface(*ctx); InsertOp::attachInterface(*ctx); InsertSliceOp::attachInterface(*ctx); + ParallelInsertSliceOp::attachInterface( + *ctx); RankOp::attachInterface(*ctx); ReshapeOp::attachInterface(*ctx); }); diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -1457,28 +1457,3 @@ // CHECK: ^[[bb3]](%[[z:.+]]: i64): // CHECK: "test.bar"(%[[z]]) // CHECK: return - -// ----- - -// CHECK-LABEL: func.func @canonicalize_parallel_insert_slice_indices( -// CHECK-SAME: %[[arg0:[0-9a-z]*]]: tensor, -// CHECK-SAME: %[[arg1:[0-9a-z]*]]: tensor, -// CHECK-SAME: %[[num_threads:[0-9a-z]*]]: index -func.func @canonicalize_parallel_insert_slice_indices( - %arg0 : tensor, %arg1: tensor, - %num_threads : index) -> tensor -{ - %cst = arith.constant 4.200000e+01 : f32 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - - // CHECK: scf.foreach_thread (%[[tidx:[0-9a-z]*]]) in (%[[num_threads]]) -> (tensor) { - // CHECK-NEXT: scf.foreach_thread.perform_concurrently { - // CHECK-NEXT: scf.foreach_thread.parallel_insert_slice %[[arg0]] into %[[arg1]][%[[tidx]], 0] [1, 5] [1, 1] - %2 = scf.foreach_thread (%tidx) in (%num_threads) -> (tensor) { - scf.foreach_thread.perform_concurrently { - scf.foreach_thread.parallel_insert_slice %arg0 into %arg1[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor into tensor - } - } - return %2 : tensor -} diff --git a/mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir b/mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir --- 
a/mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir +++ b/mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir @@ -1,36 +1,36 @@ -// RUN: mlir-opt %s -scf-for-loop-canonicalization -canonicalize | FileCheck %s +// RUN: mlir-opt %s -scf-for-loop-canonicalization | FileCheck %s -func.func @reduce() -> tensor<128xf32> { +func.func @reduce() { + // CHECK: %[[C64:.*]] = arith.constant 64 : index %c2 = arith.constant 2 : index - %cst = arith.constant dense<1.000000e+00> : tensor<1x128x384xf32> %cst_0 = arith.constant -0.000000e+00 : f32 - %0 = linalg.init_tensor [128, 384] : tensor<128x384xf32> - %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<128x384xf32>) -> tensor<128x384xf32> - %2 = linalg.init_tensor [128] : tensor<128xf32> - %3 = linalg.fill ins(%cst_0 : f32) outs(%2 : tensor<128xf32>) -> tensor<128xf32> - %4 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) { + %0 = memref.alloc() : memref<128x384xf32> + linalg.fill ins(%cst_0 : f32) outs(%0 : memref<128x384xf32>) + %2 = memref.alloc() : memref<128xf32> + linalg.fill ins(%cst_0 : f32) outs(%2 : memref<128xf32>) + scf.foreach_thread (%arg0) in (%c2) { %7 = affine.min affine_map<(d0) -> (d0 * -64 + 128, 64)>(%arg0) %8 = affine.max affine_map<(d0) -> (0, d0)>(%7) %9 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg0) %10 = affine.min affine_map<(d0, d1) -> (d1 * -64 + 128, d0)>(%8, %arg0) - // CHECK: tensor.extract_slice %{{.*}}[%{{.*}}, 0] [64, 384] [1, 1] : tensor<128x384xf32> to tensor<64x384xf32> - // CHECK: tensor.extract_slice %{{.*}}[%{{.*}}] [64] [1] : tensor<128xf32> to tensor<64xf32> - %11 = tensor.extract_slice %1[%9, 0] [%10, 384] [1, 1] : tensor<128x384xf32> to tensor - %12 = tensor.extract_slice %3[%9] [%10] [1] : tensor<128xf32> to tensor + // CHECK: memref.subview %{{.*}}[%{{.*}}, 0] [%[[C64]], 384] [1, 1] : memref<128x384xf32> to memref + // CHECK: memref.subview %{{.*}}[%{{.*}}] [%[[C64]]] [1] : memref<128xf32> to memref + %11 = memref.subview %0[%9, 0] [%10, 384] [1, 1] 
: + memref<128x384xf32> to memref (d0 * 384 + s0 + d1)>> + %12 = memref.subview %2[%9] [%10] [1] : + memref<128xf32> to memref (d0 + s0)>> - // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<64x384xf32>) outs(%{{.*}} : tensor<64xf32>) { - %13 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%11 : tensor) outs(%12 : tensor) { - ^bb0(%arg1: f32, %arg2: f32): - %14 = arith.addf %arg1, %arg2 : f32 - linalg.yield %14 : f32 - } -> tensor - - // CHECK-NOT: tensor.cast - // CHECK: scf.foreach_thread.parallel_insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [64] [1] : tensor<64xf32> into tensor<128xf32> - scf.foreach_thread.perform_concurrently { - scf.foreach_thread.parallel_insert_slice %13 into %3[%9] [%10] [1] : tensor into tensor<128xf32> - } + // CHECK: linalg.generic {{.*}} ins(%{{.*}} : memref) outs(%{{.*}} : memref) + linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%11 : memref (d0 * 384 + s0 + d1)>>) + outs(%12 : memref (d0 + s0)>>) { + ^bb0(%arg1: f32, %arg2: f32): + %14 = arith.addf %arg1, %arg2 : f32 + linalg.yield %14 : f32 + } } - return %4 : tensor<128xf32> + return } diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir --- a/mlir/test/Dialect/SCF/invalid.mlir +++ b/mlir/test/Dialect/SCF/invalid.mlir @@ -531,7 +531,7 @@ %result:2 = scf.foreach_thread (%thread_idx) in (%num_threads) -> (tensor<100xf32>, tensor<100xf32>) { %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32> scf.foreach_thread.perform_concurrently { - scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] : + tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] : tensor<1xf32> into tensor<100xf32> } } @@ -548,7 +548,7 @@ %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> (tensor) { %1 = 
tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32> scf.foreach_thread.perform_concurrently { - scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] : + tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] : tensor<1xf32> into tensor<100xf32> } } @@ -563,9 +563,9 @@ %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> (tensor<100xf32>) { %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32> - // expected-error @+1 {{expected only scf.foreach_thread.parallel_insert_slice ops}} + // expected-error @+1 {{expected only tensor.parallel_insert_slice ops}} scf.foreach_thread.perform_concurrently { - scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] : + tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] : tensor<1xf32> into tensor<100xf32> %0 = arith.constant 1: index } diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir @@ -124,10 +124,10 @@ %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<100xf32> { // CHECK: tensor.extract_slice // CHECK: scf.foreach_thread.perform_concurrently - // CHECK: scf.foreach_thread.parallel_insert_slice %{{.*}} into %[[alloc]] + // CHECK: tensor.parallel_insert_slice %{{.*}} into %[[alloc]] %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32> scf.foreach_thread.perform_concurrently { - scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] : + tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] : tensor<1xf32> into tensor<100xf32> } // CHECK: } {thread_dim_mapping = [5]} diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir --- 
a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir @@ -537,7 +537,7 @@ // CHECK-NOT: scf.foreach_thread.perform_concurrently // CHECK-NOT: parallel_insert_slice scf.foreach_thread.perform_concurrently { - scf.foreach_thread.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] : + tensor.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] : tensor into tensor } } @@ -589,7 +589,7 @@ // CHECK-NOT: scf.foreach_thread.perform_concurrently // CHECK-NOT: parallel_insert_slice scf.foreach_thread.perform_concurrently { - scf.foreach_thread.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] : + tensor.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] : tensor into tensor } } @@ -627,7 +627,7 @@ // CHECK: linalg.matmul ins({{.*}}memref<4x8xf32, #[[$DYN_LAYOUT_MAP]]>, memref<8x4xf32, #[[$DYN_LAYOUT_MAP]]>) outs({{.*}} : memref<4x4xf32, #[[$DYN_LAYOUT_MAP]]>) %8 = linalg.matmul ins(%3, %6 : tensor<4x8xf32>, tensor<8x4xf32>) outs(%7 : tensor<4x4xf32>) -> tensor<4x4xf32> scf.foreach_thread.perform_concurrently { - scf.foreach_thread.parallel_insert_slice %8 into %arg2[%1, %4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x8xf32> + tensor.parallel_insert_slice %8 into %arg2[%1, %4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x8xf32> } } return %0 : tensor<8x8xf32> diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir --- a/mlir/test/Dialect/SCF/ops.mlir +++ b/mlir/test/Dialect/SCF/ops.mlir @@ -319,14 +319,14 @@ // CHECK: scf.foreach_thread // CHECK-NEXT: tensor.extract_slice // CHECK-NEXT: scf.foreach_thread.perform_concurrently - // CHECK-NEXT: scf.foreach_thread.parallel_insert_slice + // CHECK-NEXT: tensor.parallel_insert_slice // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<100xf32> { %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32> scf.foreach_thread.perform_concurrently { - 
scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] : + tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] : tensor<1xf32> into tensor<100xf32> } } diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -1425,3 +1425,28 @@ // CHECK: return %[[E]] : tensor<16xf32> return %1 : tensor<16xf32> } + +// ----- + +// CHECK-LABEL: func.func @canonicalize_parallel_insert_slice_indices( +// CHECK-SAME: %[[arg0:[0-9a-z]*]]: tensor, +// CHECK-SAME: %[[arg1:[0-9a-z]*]]: tensor, +// CHECK-SAME: %[[num_threads:[0-9a-z]*]]: index +func.func @canonicalize_parallel_insert_slice_indices( + %arg0 : tensor, %arg1: tensor, + %num_threads : index) -> tensor +{ + %cst = arith.constant 4.200000e+01 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + // CHECK: scf.foreach_thread (%[[tidx:[0-9a-z]*]]) in (%[[num_threads]]) -> (tensor) { + // CHECK-NEXT: scf.foreach_thread.perform_concurrently { + // CHECK-NEXT: tensor.parallel_insert_slice %[[arg0]] into %[[arg1]][%[[tidx]], 0] [1, 5] [1, 1] + %2 = scf.foreach_thread (%tidx) in (%num_threads) -> (tensor) { + scf.foreach_thread.perform_concurrently { + tensor.parallel_insert_slice %arg0 into %arg1[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor into tensor + } + } + return %2 : tensor +}