diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h --- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h @@ -76,9 +76,15 @@ void operator()(std::function fun = nullptr) { (*builder)(fun); } private: - typedef typename std::conditional::value, - AffineLoopNestBuilder, - LoopNestRangeBuilder>::type BuilderType; + using LoopOrAffineLoopBuilder = + typename std::conditional::value, + AffineLoopNestBuilder, + LoopNestRangeBuilder>::type; + using BuilderType = + typename std::conditional::value, + ParallelLoopNestBuilder, + LoopOrAffineLoopBuilder>::type; + std::unique_ptr builder; }; diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -33,6 +33,10 @@ /// std.load/std.store accesses. std::unique_ptr> createConvertLinalgToLoopsPass(); +/// Create a pass to convert Linalg operations to loop.parallel loops and +/// std.load/std.store accesses. +std::unique_ptr> createConvertLinalgToParallelLoopsPass(); + /// Create a pass to convert Linalg operations to affine.for loops and /// affine_load/affine_store accesses. /// Placeholder for now, this is NYI. 
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td @@ -79,6 +79,10 @@ "if (failed(linalgOpToLoops<" # OpType # ">($_builder, op))) " # " return matchFailure();">; +class LinalgOpToParallelLoops : NativeCodeCall< + "if (failed(linalgOpToParallelLoops<" # OpType # ">($_builder, op))) " # + " return matchFailure();">; + class LinalgOpToAffineLoops : NativeCodeCall< "if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, op))) " # " return matchFailure();">; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h @@ -74,6 +74,10 @@ template LogicalResult linalgOpToLoops(PatternRewriter &rewriter, Operation *op); +/// Emits a loop nest of `loop.parallel` with the proper body for `op`. +template +LogicalResult linalgOpToParallelLoops(PatternRewriter &rewriter, Operation *op); + /// Emits a loop nest of `affine.for` with the proper body for `op`. template LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op); diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.h b/mlir/include/mlir/Dialect/LoopOps/LoopOps.h --- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.h +++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.h @@ -43,6 +43,10 @@ /// not an induction variable, then return nullptr. ForOp getForInductionVarOwner(Value val); +/// Returns the parallel loop parent of an induction variable. If the provided +/// value is not an induction variable, then return nullptr. 
+ParallelOp getParallelForInductionVarOwner(Value val); + } // end namespace loop } // end namespace mlir #endif // MLIR_LOOPOPS_OPS_H_ diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td --- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td +++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td @@ -176,12 +176,18 @@ Variadic:$upperBound, Variadic:$step); let results = (outs Variadic:$results); - let regions = (region SizedRegion<1>:$body); + let regions = (region SizedRegion<1>:$region); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<"Builder *builder, OperationState &result, " + "ValueRange lowerBounds, ValueRange upperBounds, ValueRange steps"> + ]; let extraClassDeclaration = [{ + Block *getBody() { return &region().front(); } iterator_range getInductionVars() { - Block &block = body().front(); - return {block.args_begin(), block.args_end()}; + return {getBody()->args_begin(), getBody()->args_end()}; } }]; } diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -155,6 +155,13 @@ static LoopBuilder makeAffine(ValueHandle *iv, ArrayRef lbHandles, ArrayRef ubHandles, int64_t step); + /// Constructs a new loop::ParallelOp and captures the associated induction + /// variables. An array of ValueHandle pointers is passed as the first + /// argument and is the *only* way to capture loop induction variables. + static LoopBuilder makeParallel(ArrayRef ivs, + ArrayRef lbHandles, + ArrayRef ubHandles, + ArrayRef steps); /// Constructs a new loop::ForOp and captures the associated induction /// variable. A ValueHandle pointer is passed as the first argument and is the /// *only* way to capture the loop induction variable. 
@@ -214,6 +221,18 @@ SmallVector loops; }; +class ParallelLoopNestBuilder { +public: + ParallelLoopNestBuilder(ArrayRef ivs, + ArrayRef lbs, ArrayRef ubs, + ArrayRef steps); + + void operator()(function_ref fun = nullptr); + +private: + SmallVector loops; +}; + /// Helper class to sugar building loop.for loop nests from ranges. /// This is similar to edsc::AffineLoopNestBuilder except it operates on /// loop.for. diff --git a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp --- a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp +++ b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp @@ -272,7 +272,7 @@ } // Now copy over the contents of the body. - for (auto &op : parallelOp.body().front().without_terminator()) + for (auto &op : parallelOp.getBody()->without_terminator()) rewriter.clone(op, mapping); rewriter.eraseOp(parallelOp); diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -89,6 +89,7 @@ namespace mlir { namespace edsc { + template <> GenericLoopNestRangeBuilder::GenericLoopNestRangeBuilder( ArrayRef ivs, ArrayRef ranges) { @@ -105,12 +106,30 @@ assert(range.getType() && "expected linalg.range type"); assert(range.getDefiningOp() && "need operations to extract range parts"); RangeOp rangeOp = cast(range.getDefiningOp()); - lbs.emplace_back(ValueHandle(rangeOp.min())); - ubs.emplace_back(ValueHandle(rangeOp.max())); - steps.emplace_back(ValueHandle(rangeOp.step())); + lbs.emplace_back(rangeOp.min()); + ubs.emplace_back(rangeOp.max()); + steps.emplace_back(rangeOp.step()); } builder = std::make_unique(ivs, lbs, ubs, steps); } + +template <> +GenericLoopNestRangeBuilder::GenericLoopNestRangeBuilder( + ArrayRef ivs, ArrayRef ranges) { + SmallVector lbs; + SmallVector ubs; + SmallVector steps; + for (Value range : ranges) 
{ + assert(range.getType() && "expected linalg.range type"); + assert(range.getDefiningOp() && "need operations to extract range parts"); + RangeOp rangeOp = cast(range.getDefiningOp()); + lbs.emplace_back(rangeOp.min()); + ubs.emplace_back(rangeOp.max()); + steps.emplace_back(rangeOp.step()); + } + builder = std::make_unique(ivs, lbs, ubs, steps); +} + } // namespace edsc } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -565,6 +565,14 @@ LowerLinalgToLoopsPass>(); } +/// Create a pass to convert Linalg operations to loop.parallel loops and +/// std.load/std.store accesses. +std::unique_ptr> +mlir::createConvertLinalgToParallelLoopsPass() { + return std::make_unique< + LowerLinalgToLoopsPass>(); +} + /// Create a pass to convert Linalg operations to affine.for loops and /// affine_load/affine_store accesses. /// Placeholder for now, this is NYI. @@ -590,6 +598,14 @@ op, rewriter); } +// Emits a loop nest of `loop.parallel` with the proper body for `op`. +template +LogicalResult mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter, + Operation *op) { + return LinalgOpToLoopsImpl::doit(op, rewriter); +} + // TODO(ntv) Need to make these instantiations more future-proof to avoid the // need to update as soon as we add new ops. #define INSTANTIATE_LINALG_OP_TO_LOOPS(OP_TYPE) \ @@ -607,11 +623,23 @@ INSTANTIATE_LINALG_OP_TO_LOOPS(GenericOp) INSTANTIATE_LINALG_OP_TO_LOOPS(IndexedGenericOp) +// TODO(pifon): Only linalg.generic is lowered to parallel loops for now, to be +// on the safe side. Enable the lowering for the other Linalg ops as well. 
+template LogicalResult +mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter, + Operation *op); + static PassRegistration> structuredLoopsPass( "convert-linalg-to-loops", "Lower the operations from the linalg dialect into loops"); +static PassRegistration< + LowerLinalgToLoopsPass> + parallelLoopsPass( + "convert-linalg-to-parallel-loops", + "Lower the operations from the linalg dialect into parallel loops"); + static PassRegistration> affineLoopsPass( "convert-linalg-to-affine-loops", diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp --- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp +++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp @@ -225,6 +225,18 @@ // ParallelOp //===----------------------------------------------------------------------===// +void ParallelOp::build(Builder *builder, OperationState &result, ValueRange lbs, + ValueRange ubs, ValueRange steps) { + result.addOperands(lbs); + result.addOperands(ubs); + result.addOperands(steps); + Region *bodyRegion = result.addRegion(); + ForOp::ensureTerminator(*bodyRegion, *builder, result.location); + for (size_t i = 0; i < steps.size(); ++i) { + bodyRegion->front().addArgument(builder->getIndexType()); + } +} + static LogicalResult verify(ParallelOp op) { // Check that there is at least one value in lowerBound, upperBound and step. // It is sufficient to test only step, because it is ensured already that the @@ -242,7 +254,7 @@ // Check that the body defines the same number of block arguments as the // number of tuple elements in step. 
- Block *body = &op.body().front(); + Block *body = op.getBody(); if (body->getNumArguments() != stepValues.size()) return op.emitOpError( "expects the same number of induction variables as bound and step " @@ -322,15 +334,24 @@ static void print(OpAsmPrinter &p, ParallelOp op) { p << op.getOperationName() << " ("; - p.printOperands(op.body().front().getArguments()); + p.printOperands(op.getBody()->getArguments()); p << ") = (" << op.lowerBound() << ") to (" << op.upperBound() << ") step (" << op.step() << ")"; - p.printRegion(op.body(), /*printEntryBlockArgs=*/false); + p.printRegion(op.region(), /*printEntryBlockArgs=*/false); p.printOptionalAttrDict(op.getAttrs()); if (!op.results().empty()) p << " : " << op.getResultTypes(); } +ParallelOp mlir::loop::getParallelForInductionVarOwner(Value val) { + auto ivArg = val.dyn_cast(); + if (!ivArg) + return ParallelOp(); + assert(ivArg.getOwner() && "unlinked block argument"); + auto *containingInst = ivArg.getOwner()->getParentOp(); + return dyn_cast_or_null(containingInst); +} + //===----------------------------------------------------------------------===// // ReduceOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -188,6 +188,24 @@ return result; } +mlir::edsc::LoopBuilder mlir::edsc::LoopBuilder::makeParallel( + ArrayRef ivs, ArrayRef lbHandles, + ArrayRef ubHandles, ArrayRef steps) { + mlir::edsc::LoopBuilder result; + auto opHandle = OperationHandle::create( + SmallVector(lbHandles.begin(), lbHandles.end()), + SmallVector(ubHandles.begin(), ubHandles.end()), + SmallVector(steps.begin(), steps.end())); + + loop::ParallelOp parallelOp = + cast(*opHandle.getOperation()); + for (size_t i = 0; i < ivs.size(); ++i) { + *ivs[i] = ValueHandle(parallelOp.getBody()->getArgument(i)); + } + result.enter(parallelOp.getBody(), /*prev=*/1); + return result; +}
+ mlir::edsc::LoopBuilder mlir::edsc::LoopBuilder::makeLoop(ValueHandle *iv, ValueHandle lbHandle, ValueHandle ubHandle, @@ -255,6 +273,29 @@ (*lit)(); } +mlir::edsc::ParallelLoopNestBuilder::ParallelLoopNestBuilder( + ArrayRef ivs, ArrayRef lbs, + ArrayRef ubs, ArrayRef steps) { + assert(ivs.size() == lbs.size() && "Mismatch in number of arguments"); + assert(ivs.size() == ubs.size() && "Mismatch in number of arguments"); + assert(ivs.size() == steps.size() && "Mismatch in number of arguments"); + + loops.emplace_back(LoopBuilder::makeParallel(ivs, lbs, ubs, steps)); +} + +void mlir::edsc::ParallelLoopNestBuilder::operator()( + function_ref fun) { + if (fun) + fun(); + // Call operator() on every loop in the nest, from innermost to outermost. + // The nest built here holds a single loop.parallel, but the reverse order + // is kept because enter/exit are asymmetric: enter() occurs on LoopBuilder + // construction and exit() occurs on calling operator(). This is required for + // properly nesting imperfectly nested regions (see LoopBuilder::operator()). 
+ for (auto lit = loops.rbegin(), eit = loops.rend(); lit != eit; ++lit) + (*lit)(); +} + mlir::edsc::LoopNestBuilder::LoopNestBuilder(ArrayRef ivs, ArrayRef lbs, ArrayRef ubs, diff --git a/mlir/test/Dialect/Linalg/parallel_loops.mlir b/mlir/test/Dialect/Linalg/parallel_loops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/parallel_loops.mlir @@ -0,0 +1,29 @@ +// RUN: mlir-opt %s -convert-linalg-to-parallel-loops -split-input-file | FileCheck %s + +#map0 = affine_map<(d0, d1) -> (d0, d1)> +func @linalg_generic_sum(%lhs: memref<2x2xf32>, + %rhs: memref<2x2xf32>, + %sum: memref<2x2xf32>) { + linalg.generic { + args_in = 2 : i64, + args_out = 1 : i64, + indexing_maps = [#map0, #map0, #map0], + iterator_types = ["parallel", "parallel"] + } %lhs, %rhs, %sum { + ^bb0(%lhs_in: f32, %rhs_in: f32, %sum_out: f32): // no predecessors + %0 = addf %lhs_in, %rhs_in : f32 + linalg.yield %0 : f32 + }: memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32> + return +} +// CHECK-LABEL: @linalg_generic_sum +// CHECK: %[[C2:.*]] = constant 2 +// CHECK: %[[C0:.*]] = constant 0 +// CHECK: %[[C1:.*]] = constant 1 +// CHECK: loop.parallel (%[[I:.*]], %[[J:.*]]) = {{.*}} +// CHECK: %[[LHS_ELEM:.*]] = load %{{.*}}[%[[I]], %[[J]]] +// CHECK: %[[RHS_ELEM:.*]] = load %{{.*}}[%[[I]], %[[J]]] +// CHECK: %[[SUM_ELEM:.*]] = load %{{.*}}[%[[I]], %[[J]]] +// CHECK: %[[SUM:.*]] = addf %[[LHS_ELEM]], %[[RHS_ELEM]] : f32 +// CHECK: store %[[SUM]], %{{.*}}[%[[I]], %[[J]]] +// CHECK: "loop.terminator"() : () -> ()