diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h --- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h +++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h @@ -18,6 +18,7 @@ #include "mlir/IR/RegionKindInterface.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Interfaces/ParallelCombiningOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -16,6 +16,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/IR/RegionKindInterface.td" +include "mlir/Interfaces/ParallelCombiningOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" @@ -468,6 +469,7 @@ def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [ NoSideEffect, Terminator, + DeclareOpInterfaceMethods, HasParent<"ForeachThreadOp">, ] # GraphRegionNoTerminator.traits> { let summary = "terminates a `foreach_thread` block"; @@ -495,8 +497,9 @@ // TODO: Add a `PerformConcurrentlyOpInterface` interface for ops that can // appear inside perform_concurrently. let extraClassDeclaration = [{ - SmallVector yieldedTypes(); - ::llvm::iterator_range yieldingOps(); + ::llvm::SmallVector<::mlir::Type> getYieldedTypes(); + ::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps(); + ::mlir::OpResult getParentResult(int64_t idx); }]; } @@ -508,7 +511,9 @@ def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [ AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface, - HasParent<"PerformConcurrentlyOp">]> { + // TODO: Cannot use an interface here atm, verify this manually for now. 
+ // HasParent<"ParallelCombiningOpInterface"> + ]> { let summary = [{ Specify the tensor slice update of a single thread within the terminator of an `scf.foreach_thread`. @@ -599,6 +604,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -6,6 +6,7 @@ add_mlir_interface(InferIntRangeInterface) add_mlir_interface(InferTypeOpInterface) add_mlir_interface(LoopLikeInterface) +add_mlir_interface(ParallelCombiningOpInterface) add_mlir_interface(SideEffectInterfaces) add_mlir_interface(TilingInterface) add_mlir_interface(VectorInterfaces) diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h @@ -0,0 +1,24 @@ +//===- ParallelCombiningOpInterface.h - Parallel combining op interface ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the operation interface for ops that perform parallel +// combining operations. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_ +#define MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_ + +#include "mlir/IR/OpDefinition.h" + +namespace mlir {} // namespace mlir + +/// Include the generated interface declarations. 
+#include "mlir/Interfaces/ParallelCombiningOpInterface.h.inc" + +#endif // MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_ diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td @@ -0,0 +1,71 @@ +//===- ParallelCombiningOpInterface.td - Parallel iface ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the interface for ops that perform parallel combining operations. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE +#define MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE + +include "mlir/IR/OpBase.td" + +def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> { + let description = [{ + A parallel combining op is an op with a region that is not isolated from + above and yields values to its parent op without itself returning an SSA + value. The yielded values are determined by subvalues produced by the ops + contained in the region (the `yieldingOps`) and combined in an unspecified + order to produce the values yielded to the parent op. + + This is useful as a terminator to parallel operations that iterate over + some set and return tensors while avoiding tight coupling between the + iterating op, the combining op and the individual subtensor producing ops. + }]; + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Return the `idx`^th result of the parent operation. 
+ }], + /*retTy=*/"::mlir::OpResult", + /*methodName=*/"getParentResult", + /*args=*/(ins "int64_t":$idx), + /*methodBody=*/[{ + return $_op.getParentResult(idx); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the contained ops that yield subvalues that this op combines to + yield to its parent. + }], + /*retTy=*/"::llvm::iterator_range", + /*methodName=*/"getYieldingOps", + /*args=*/(ins), + /*methodBody=*/[{ + return $_op.getYieldingOps(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the types of the values that this op combines and yields to its + parent. + }], + /*retTy=*/"::llvm::SmallVector<::mlir::Type>", + /*methodName=*/"getYieldedTypes", + /*args=*/(ins), + /*methodBody=*/[{ + return $_op.getYieldedTypes(); + }] + >, + ]; +} + +#endif // MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt --- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt @@ -13,6 +13,7 @@ MLIRControlFlowDialect MLIRIR MLIRLoopLikeInterface + MLIRParallelCombiningOpInterface MLIRSideEffectInterfaces ) diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1061,7 +1061,7 @@ return emitOpError("region expects ") << getRank() << " arguments"; // Verify consistency between the result types and the terminator. 
- auto terminatorTypes = getTerminator().yieldedTypes(); + auto terminatorTypes = getTerminator().getYieldedTypes(); auto opResults = getResults(); if (opResults.size() != terminatorTypes.size()) return emitOpError("produces ") @@ -1182,7 +1182,7 @@ llvm::dyn_cast(bodyBlock.getTerminator()); assert(terminator && "expected bodyBuilder to create PerformConcurrentlyOp terminator"); - result.addTypes(terminator.yieldedTypes()); + result.addTypes(terminator.getYieldedTypes()); } // The ensureTerminator method generated by SingleBlockImplicitTerminator is @@ -1216,15 +1216,15 @@ //===----------------------------------------------------------------------===// OpResult ParallelInsertSliceOp::getTiedOpResult() { - auto foreachThreadOp = getOperation()->getParentOfType(); - assert(foreachThreadOp && "unlinked ParallelInsertSliceOp"); - PerformConcurrentlyOp performConcurrentlyOp = foreachThreadOp.getTerminator(); - for (const auto &it : llvm::enumerate(performConcurrentlyOp.yieldingOps())) { + auto parallelCombiningParent = + cast(getOperation()->getParentOp()); + for (const auto &it : + llvm::enumerate(parallelCombiningParent.getYieldingOps())) { Operation &nextOp = it.value(); if (&nextOp == getOperation()) - return foreachThreadOp->getResult(it.index()); + return parallelCombiningParent.getParentResult(it.index()); } - llvm_unreachable("ParallelInsertSliceOp not found"); + llvm_unreachable("ParallelInsertSliceOp no tied OpResult found"); } // Build a ParallelInsertSliceOp with mixed static and dynamic entries. @@ -1262,6 +1262,13 @@ build(b, result, source, dest, offsetValues, sizeValues, strideValues); } +LogicalResult ParallelInsertSliceOp::verify() { + if (!isa(getOperation()->getParentOp())) + return this->emitError("expected ParallelCombiningOpInterface parent, got:") + << *(getOperation()->getParentOp()); + return success(); +} + namespace { /// Pattern to rewrite a parallel_insert_slice op with constant arguments. 
class ParallelInsertSliceOpConstantArgumentFolder final @@ -1382,15 +1389,19 @@ return success(); } -SmallVector PerformConcurrentlyOp::yieldedTypes() { +OpResult PerformConcurrentlyOp::getParentResult(int64_t idx) { + return getOperation()->getParentOp()->getResult(idx); +} + +SmallVector PerformConcurrentlyOp::getYieldedTypes() { return llvm::to_vector<4>( - llvm::map_range(this->yieldingOps(), [](Operation &op) { + llvm::map_range(getYieldingOps(), [](Operation &op) { auto insertSliceOp = dyn_cast(&op); return insertSliceOp ? insertSliceOp.yieldedType() : Type(); })); } -llvm::iterator_range PerformConcurrentlyOp::yieldingOps() { +llvm::iterator_range PerformConcurrentlyOp::getYieldingOps() { return getRegion().front().getOperations(); } diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -1043,14 +1043,7 @@ if (&opOperand != &op->getOpOperand(1) /*dest*/) return {}; - // ParallelInsertSliceOp itself has no results. Tensors are returned via - // the parent op. - auto foreachThreadOp = op->getParentOfType(); - assert(foreachThreadOp && - "could not find valid owner of parallel_insert_slice"); - - // The i-th ParallelInsertSliceOp result is returned via the i-th OpResult - // of the parent ForeachThreadOp. + // The i-th ParallelInsertSliceOp result is returned via its enclosing op.. 
Block *block = op->getBlock(); unsigned int opIdx = 0; for (ParallelInsertSliceOp insertOp : @@ -1059,10 +1052,9 @@ break; ++opIdx; } - assert(opIdx < foreachThreadOp->getNumResults() && - "could not find op inside terminator op"); - - return {foreachThreadOp->getResult(opIdx)}; + auto parallelCombiningParent = + cast(op->getParentOp()); + return {parallelCombiningParent.getParentResult(opIdx)}; } bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, @@ -1107,7 +1099,9 @@ OpBuilder::InsertionGuard g(rewriter); auto insertOp = cast(op); - auto foreachThreadOp = insertOp->getParentOfType(); + auto parallelCombiningParent = + cast(op->getParentOp()); + Operation *parallelIteratingOp = parallelCombiningParent->getParentOp(); // Nothing to do if the destination tensor is inplace. assert(state.isInPlace(op->getOpOperand(0) /*src*/) && @@ -1119,7 +1113,7 @@ OpResult opResult = insertOp.getTiedOpResult(); // Insert tensor allocation right before the ForeachThreadOp. - rewriter.setInsertionPoint(foreachThreadOp); + rewriter.setInsertionPoint(parallelIteratingOp); bool isYielded = state.isTensorYielded(opResult); FailureOr alloc = allocateTensorForShapedValue(rewriter, op->getLoc(), insertOp.getDest(), @@ -1138,9 +1132,9 @@ const BufferizationOptions &options) const { OpBuilder::InsertionGuard g(rewriter); auto insertOp = cast(op); - auto performConcurrentlyOp = cast(op->getParentOp()); - auto foreachThreadOp = - cast(performConcurrentlyOp->getParentOp()); + auto parallelCombiningParent = + cast(op->getParentOp()); + Operation *parallelIteratingOp = parallelCombiningParent->getParentOp(); // Get destination buffer. FailureOr destBuffer = @@ -1148,8 +1142,8 @@ if (failed(destBuffer)) return failure(); - // Bufferize the ParallelInsertSliceOp outside of the PerformConcurrentlyOp. - rewriter.setInsertionPoint(performConcurrentlyOp); + // Bufferize the ParallelInsertSliceOp outside of `parallelCombiningParent`. 
+ rewriter.setInsertionPoint(parallelCombiningParent); FailureOr srcBuffer = getBuffer(rewriter, insertOp.getSource(), options); if (failed(srcBuffer)) @@ -1162,10 +1156,10 @@ subview))) return failure(); - // Replace all uses of ForeachThreadOp (just the corresponding result). - rewriter.setInsertionPointAfter(foreachThreadOp); + // Replace all uses of parallelIteratingOp (just the corresponding result). + rewriter.setInsertionPointAfter(parallelIteratingOp); Value toTensorOp = - rewriter.create(foreachThreadOp.getLoc(), *destBuffer); + rewriter.create(parallelIteratingOp->getLoc(), *destBuffer); // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps. SmallVector resultUses = llvm::to_vector(llvm::map_range(insertOp.getTiedOpResult().getUses(), diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -8,6 +8,7 @@ InferIntRangeInterface.cpp InferTypeOpInterface.cpp LoopLikeInterface.cpp + ParallelCombiningOpInterface.cpp SideEffectInterfaces.cpp TilingInterface.cpp VectorInterfaces.cpp @@ -38,6 +39,7 @@ add_mlir_interface_library(DerivedAttributeOpInterface) add_mlir_interface_library(InferIntRangeInterface) add_mlir_interface_library(InferTypeOpInterface) +add_mlir_interface_library(ParallelCombiningOpInterface) add_mlir_interface_library(SideEffectInterfaces) add_mlir_interface_library(TilingInterface) add_mlir_interface_library(VectorInterfaces) diff --git a/mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp b/mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp @@ -0,0 +1,18 @@ +//===- ParallelCombiningOpInterface.cpp - Parallel combining op interface -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/ParallelCombiningOpInterface.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// ParallelCombiningOpInterface +//===----------------------------------------------------------------------===// + +/// Include the definitions of the interface. +#include "mlir/Interfaces/ParallelCombiningOpInterface.cpp.inc"