diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h --- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h +++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h @@ -18,6 +18,7 @@ #include "mlir/IR/RegionKindInterface.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Interfaces/ParallelCombiningOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -16,6 +16,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/IR/RegionKindInterface.td" +include "mlir/Interfaces/ParallelCombiningOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" @@ -468,6 +469,7 @@ def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [ NoSideEffect, Terminator, + DeclareOpInterfaceMethods, HasParent<"ForeachThreadOp">, ] # GraphRegionNoTerminator.traits> { let summary = "terminates a `foreach_thread` block"; @@ -495,8 +497,9 @@ // TODO: Add a `PerformConcurrentlyOpInterface` interface for ops that can // appear inside perform_concurrently. let extraClassDeclaration = [{ - SmallVector yieldedTypes(); - ::llvm::iterator_range yieldingOps(); + ::llvm::SmallVector<::mlir::Type> getYieldedTypes(); + ::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps(); + ::mlir::OpResult getParentResult(int64_t idx); }]; } @@ -508,7 +511,9 @@ def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [ AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface, - HasParent<"PerformConcurrentlyOp">]> { + // TODO: HasParent cannot take an interface yet; parent is checked in verify().
+ // HasParent<"ParallelCombiningOpInterface"> + ]> { let summary = [{ Specify the tensor slice update of a single thread within the terminator of an `scf.foreach_thread`. @@ -568,6 +573,11 @@ return getSource().getType().cast(); } + ParallelCombiningOpInterface getParallelCombiningParent() { + return dyn_cast( + getOperation()->getParentOp()); + } + /// Return the expected rank of each of the `static_offsets`, `static_sizes` /// and `static_strides` attributes. std::array getArrayAttrMaxRanks() { @@ -599,6 +609,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -6,6 +6,7 @@ add_mlir_interface(InferIntRangeInterface) add_mlir_interface(InferTypeOpInterface) add_mlir_interface(LoopLikeInterface) +add_mlir_interface(ParallelCombiningOpInterface) add_mlir_interface(SideEffectInterfaces) add_mlir_interface(TilingInterface) add_mlir_interface(VectorInterfaces) diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h @@ -0,0 +1,29 @@ +//===- ParallelCombiningOpInterface.h - Parallel combining op interface ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the operation interface for ops that perform parallel +// combining operations.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_ +#define MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_ + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace detail { +// TODO: Express single region/block as interface-level traits if possible. +LogicalResult verifyParallelCombiningOpInterface(Operation *op); +} // namespace detail +} // namespace mlir + +/// Include the generated interface declarations. +#include "mlir/Interfaces/ParallelCombiningOpInterface.h.inc" + +#endif // MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_ diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td @@ -0,0 +1,75 @@ +//===- ParallelCombiningOpInterface.td - Parallel iface ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the interface for ops that perform parallel combining operations. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE +#define MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE + +include "mlir/IR/OpBase.td" + +def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> { + let description = [{ + A parallel combining op is an op with a region, that is not isolated from + above and yields values to its parent op without itself returning an SSA + value.
The yielded values are determined by subvalues produced by the ops + contained in the region (the `yieldingOps`) and combined in any unspecified + order to produce the values yielded to the parent op. + + This is useful as a terminator to parallel operations that iterate over + some set and return tensors while avoiding tight coupling between the + iterating op, the combining op and the individual subtensor producing ops. + }]; + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Return the `idx`-th result of the parent operation. + }], + /*retTy=*/"::mlir::OpResult", + /*methodName=*/"getParentResult", + /*args=*/(ins "int64_t":$idx), + /*methodBody=*/[{ + return $_op.getParentResult(idx); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the contained ops that yield subvalues that this op combines to + yield to its parent. + }], + /*retTy=*/"::llvm::iterator_range", + /*methodName=*/"getYieldingOps", + /*args=*/(ins), + /*methodBody=*/[{ + return $_op.getYieldingOps(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the types of the values that this op combines and yields to its + parent op. + }], + /*retTy=*/"::llvm::SmallVector<::mlir::Type>", + /*methodName=*/"getYieldedTypes", + /*args=*/(ins), + /*methodBody=*/[{ + return $_op.getYieldedTypes(); + }] + >, + ]; + // TODO: Express single region/block as interface-level traits if possible.
+ let verify = [{ + return ::mlir::detail::verifyParallelCombiningOpInterface($_op); + }]; +} + +#endif // MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt --- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt @@ -13,6 +13,7 @@ MLIRControlFlowDialect MLIRIR MLIRLoopLikeInterface + MLIRParallelCombiningOpInterface MLIRSideEffectInterfaces ) diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1061,7 +1061,7 @@ return emitOpError("region expects ") << getRank() << " arguments"; // Verify consistency between the result types and the terminator. - auto terminatorTypes = getTerminator().yieldedTypes(); + auto terminatorTypes = getTerminator().getYieldedTypes(); auto opResults = getResults(); if (opResults.size() != terminatorTypes.size()) return emitOpError("produces ") @@ -1182,7 +1182,7 @@ llvm::dyn_cast(bodyBlock.getTerminator()); assert(terminator && "expected bodyBuilder to create PerformConcurrentlyOp terminator"); - result.addTypes(terminator.yieldedTypes()); + result.addTypes(terminator.getYieldedTypes()); } // The ensureTerminator method generated by SingleBlockImplicitTerminator is @@ -1216,15 +1216,15 @@ //===----------------------------------------------------------------------===// OpResult ParallelInsertSliceOp::getTiedOpResult() { - auto foreachThreadOp = getOperation()->getParentOfType(); - assert(foreachThreadOp && "unlinked ParallelInsertSliceOp"); - PerformConcurrentlyOp performConcurrentlyOp = foreachThreadOp.getTerminator(); - for (const auto &it : llvm::enumerate(performConcurrentlyOp.yieldingOps())) { + ParallelCombiningOpInterface parallelCombiningParent = + getParallelCombiningParent(); + for (const auto &it : + llvm::enumerate(parallelCombiningParent.getYieldingOps())) { Operation &nextOp = it.value(); if (&nextOp == getOperation())
- return foreachThreadOp->getResult(it.index()); + return parallelCombiningParent.getParentResult(it.index()); } - llvm_unreachable("ParallelInsertSliceOp not found"); + llvm_unreachable("ParallelInsertSliceOp no tied OpResult found"); } // Build a ParallelInsertSliceOp with mixed static and dynamic entries. @@ -1262,6 +1262,13 @@ build(b, result, source, dest, offsetValues, sizeValues, strideValues); } +LogicalResult ParallelInsertSliceOp::verify() { + if (!isa(getOperation()->getParentOp())) + return emitOpError("expected ParallelCombiningOpInterface parent, got: ") + << *(getOperation()->getParentOp()); + return success(); +} + namespace { /// Pattern to rewrite a parallel_insert_slice op with constant arguments. class ParallelInsertSliceOpConstantArgumentFolder final @@ -1382,15 +1389,19 @@ return success(); } -SmallVector PerformConcurrentlyOp::yieldedTypes() { +OpResult PerformConcurrentlyOp::getParentResult(int64_t idx) { + return getOperation()->getParentOp()->getResult(idx); +} + +SmallVector PerformConcurrentlyOp::getYieldedTypes() { return llvm::to_vector<4>( - llvm::map_range(this->yieldingOps(), [](Operation &op) { + llvm::map_range(getYieldingOps(), [](Operation &op) { auto insertSliceOp = dyn_cast(&op); return insertSliceOp ? insertSliceOp.yieldedType() : Type(); })); } -llvm::iterator_range PerformConcurrentlyOp::yieldingOps() { +llvm::iterator_range PerformConcurrentlyOp::getYieldingOps() { return getRegion().front().getOperations(); } diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -1043,8 +1043,7 @@ if (&opOperand != &op->getOpOperand(1) /*dest*/) return {}; - // ParallelInsertSliceOp itself has no results. Tensors are returned via - // the parent op.
+ // ParallelInsertSliceOp itself has no results; return its tied OpResult. auto insertOp = cast(op); return {insertOp.getTiedOpResult()}; } @@ -1090,8 +1089,10 @@ // } OpBuilder::InsertionGuard g(rewriter); - auto insertOp = cast(op); - auto foreachThreadOp = insertOp->getParentOfType(); + auto parallelInsertSliceOp = cast(op); + ParallelCombiningOpInterface parallelCombiningParent = + parallelInsertSliceOp.getParallelCombiningParent(); + Operation *parallelIteratingOp = parallelCombiningParent->getParentOp(); // Nothing to do if the destination tensor is inplace. assert(state.isInPlace(op->getOpOperand(0) /*src*/) && @@ -1100,20 +1101,21 @@ return success(); // Find corresponding OpResult. - OpResult opResult = insertOp.getTiedOpResult(); + OpResult opResult = parallelInsertSliceOp.getTiedOpResult(); // Insert tensor allocation right before the ForeachThreadOp. - rewriter.setInsertionPoint(foreachThreadOp); + rewriter.setInsertionPoint(parallelIteratingOp); bool isYielded = state.isTensorYielded(opResult); - FailureOr alloc = - allocateTensorForShapedValue(rewriter, op->getLoc(), insertOp.getDest(), - /*escape=*/isYielded, state.getOptions()); + FailureOr alloc = allocateTensorForShapedValue( + rewriter, op->getLoc(), parallelInsertSliceOp.getDest(), + /*escape=*/isYielded, state.getOptions()); if (failed(alloc)) return failure(); // Update destination operand.
- rewriter.updateRootInPlace( - insertOp, [&]() { insertOp.getDestMutable().assign(*alloc); }); + rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() { + parallelInsertSliceOp.getDestMutable().assign(*alloc); + }); return success(); } @@ -1121,39 +1123,41 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { OpBuilder::InsertionGuard g(rewriter); - auto insertOp = cast(op); - auto performConcurrentlyOp = cast(op->getParentOp()); - auto foreachThreadOp = - cast(performConcurrentlyOp->getParentOp()); + auto parallelInsertSliceOp = cast(op); + ParallelCombiningOpInterface parallelCombiningParent = + parallelInsertSliceOp.getParallelCombiningParent(); + Operation *parallelIteratingOp = parallelCombiningParent->getParentOp(); // Get destination buffer. FailureOr destBuffer = - getBuffer(rewriter, insertOp.getDest(), options); + getBuffer(rewriter, parallelInsertSliceOp.getDest(), options); if (failed(destBuffer)) return failure(); - // Bufferize the ParallelInsertSliceOp outside of the PerformConcurrentlyOp. - rewriter.setInsertionPoint(performConcurrentlyOp); + // Bufferize the ParallelInsertSliceOp outside of `parallelCombiningParent`. + rewriter.setInsertionPoint(parallelCombiningParent); FailureOr srcBuffer = - getBuffer(rewriter, insertOp.getSource(), options); + getBuffer(rewriter, parallelInsertSliceOp.getSource(), options); if (failed(srcBuffer)) return failure(); Value subview = rewriter.create( - insertOp.getLoc(), *destBuffer, insertOp.getMixedOffsets(), - insertOp.getMixedSizes(), insertOp.getMixedStrides()); + parallelInsertSliceOp.getLoc(), *destBuffer, + parallelInsertSliceOp.getMixedOffsets(), + parallelInsertSliceOp.getMixedSizes(), + parallelInsertSliceOp.getMixedStrides()); // This memcpy will fold away if everything bufferizes in-place.
- if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), *srcBuffer, - subview))) + if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(), + *srcBuffer, subview))) return failure(); - // Replace all uses of ForeachThreadOp (just the corresponding result). - rewriter.setInsertionPointAfter(foreachThreadOp); + // Replace all uses of the tied result of `parallelIteratingOp`. + rewriter.setInsertionPointAfter(parallelIteratingOp); Value toTensorOp = - rewriter.create(foreachThreadOp.getLoc(), *destBuffer); + rewriter.create(parallelIteratingOp->getLoc(), *destBuffer); // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps. - SmallVector resultUses = - llvm::to_vector(llvm::map_range(insertOp.getTiedOpResult().getUses(), - [](OpOperand &use) { return &use; })); + SmallVector resultUses = llvm::to_vector( + llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(), + [](OpOperand &use) { return &use; })); for (OpOperand *use : resultUses) { rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(toTensorOp); }); diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -8,6 +8,7 @@ InferIntRangeInterface.cpp InferTypeOpInterface.cpp LoopLikeInterface.cpp + ParallelCombiningOpInterface.cpp SideEffectInterfaces.cpp TilingInterface.cpp VectorInterfaces.cpp @@ -38,6 +39,7 @@ add_mlir_interface_library(DerivedAttributeOpInterface) add_mlir_interface_library(InferIntRangeInterface) add_mlir_interface_library(InferTypeOpInterface) +add_mlir_interface_library(ParallelCombiningOpInterface) add_mlir_interface_library(SideEffectInterfaces) add_mlir_interface_library(TilingInterface) add_mlir_interface_library(VectorInterfaces) diff --git a/mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp b/mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp new file mode 100644 --- /dev/null +++
b/mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp @@ -0,0 +1,27 @@ +//===- ParallelCombiningOpInterface.cpp - Parallel combining op interface -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/ParallelCombiningOpInterface.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// ParallelCombiningOpInterface +//===----------------------------------------------------------------------===// + +// TODO: Express single region/block as interface-level traits if possible. +LogicalResult mlir::detail::verifyParallelCombiningOpInterface(Operation *op) { + if (op->getNumRegions() != 1) + return op->emitError("expected single region op"); + if (!op->getRegion(0).hasOneBlock()) + return op->emitError("expected single block op region"); + return success(); +} + +/// Include the definitions of the interface. +#include "mlir/Interfaces/ParallelCombiningOpInterface.cpp.inc"