diff --git a/mlir/include/mlir/Dialect/LoopOps/EDSC/Builders.h b/mlir/include/mlir/Dialect/LoopOps/EDSC/Builders.h --- a/mlir/include/mlir/Dialect/LoopOps/EDSC/Builders.h +++ b/mlir/include/mlir/Dialect/LoopOps/EDSC/Builders.h @@ -33,7 +33,9 @@ /// variable. A ValueHandle pointer is passed as the first argument and is the /// *only* way to capture the loop induction variable. LoopBuilder makeLoopBuilder(ValueHandle *iv, ValueHandle lbHandle, - ValueHandle ubHandle, ValueHandle stepHandle); + ValueHandle ubHandle, ValueHandle stepHandle, + ArrayRef iter_args_handles = {}, + ArrayRef iter_args_init_values = {}); /// Helper class to sugar building loop.parallel loop nests from lower/upper /// bounds and step sizes. @@ -54,9 +56,13 @@ /// loop.for. class LoopNestBuilder { public: + LoopNestBuilder(edsc::ValueHandle *iv, ValueHandle lb, ValueHandle ub, + ValueHandle step, + ArrayRef iter_args_handles = {}, + ArrayRef iter_args_init_values = {}); LoopNestBuilder(ArrayRef ivs, ArrayRef lbs, ArrayRef ubs, ArrayRef steps); - void operator()(std::function fun = nullptr); + Operation::result_range operator()(std::function fun = nullptr); private: SmallVector loops; diff --git a/mlir/include/mlir/Dialect/LoopOps/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/LoopOps/EDSC/Intrinsics.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/LoopOps/EDSC/Intrinsics.h @@ -0,0 +1,24 @@ +//===- Intrinsics.h - MLIR EDSC Intrinsics for LoopOps ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_DIALECT_LOOPOPS_EDSC_INTRINSICS_H_ +#define MLIR_DIALECT_LOOPOPS_EDSC_INTRINSICS_H_ + +#include "mlir/Dialect/LoopOps/EDSC/Builders.h" +#include "mlir/EDSC/Intrinsics.h" + +namespace mlir { +namespace edsc { +namespace intrinsics { +using loop_yield = OperationBuilder; + +} // namespace intrinsics +} // namespace edsc +} // namespace mlir + +#endif // MLIR_DIALECT_LOOPOPS_EDSC_INTRINSICS_H_ diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -152,7 +152,6 @@ /// the evaluation of `fun` (which build IR snippets in a scoped fashion) is /// scoped within a LoopBuilder. void operator()(function_ref fun = nullptr); - private: LoopBuilder() = default; @@ -166,7 +165,9 @@ ArrayRef steps); friend LoopBuilder makeLoopBuilder(ValueHandle *iv, ValueHandle lbHandle, ValueHandle ubHandle, - ValueHandle stepHandle); + ValueHandle stepHandle, + ArrayRef iter_args_handles, + ArrayRef iter_args_init_values); }; // This class exists solely to handle the C++ vexing parse case when diff --git a/mlir/lib/Dialect/LoopOps/EDSC/Builders.cpp b/mlir/lib/Dialect/LoopOps/EDSC/Builders.cpp --- a/mlir/lib/Dialect/LoopOps/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/LoopOps/EDSC/Builders.cpp @@ -52,12 +52,32 @@ assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size"); } -void mlir::edsc::LoopNestBuilder::LoopNestBuilder::operator()( +mlir::edsc::LoopNestBuilder::LoopNestBuilder( + ValueHandle *iv, ValueHandle lb, ValueHandle ub, ValueHandle step, + ArrayRef iter_args_handles, + ArrayRef iter_args_init_values) { + assert(iter_args_init_values.size() == iter_args_handles.size() && + "expected size of arguments and argument_handles to match"); + loops.emplace_back(makeLoopBuilder(iv, lb, ub, step, iter_args_handles, + 
iter_args_init_values)); + assert(loops.size() == 1 && "Mismatch loops vs ivs size"); +} + +Operation::result_range +mlir::edsc::LoopNestBuilder::LoopNestBuilder::operator()( std::function fun) { if (fun) fun(); for (auto &lit : reverse(loops)) lit({}); + + // Get all the operations of the current block. + auto &operations = + ScopedContext::getBuilder().getInsertionBlock()->getOperations(); + // Assumes the loop built above is the last op in the insertion block. + Operation &forloopop = operations.back(); + // Return the results of that loop op. + return forloopop.getResults(); } LoopBuilder mlir::edsc::makeParallelLoopBuilder(ArrayRef ivs, @@ -78,15 +98,20 @@ return result; } -mlir::edsc::LoopBuilder mlir::edsc::makeLoopBuilder(ValueHandle *iv, - ValueHandle lbHandle, - ValueHandle ubHandle, - ValueHandle stepHandle) { +mlir::edsc::LoopBuilder +mlir::edsc::makeLoopBuilder(ValueHandle *iv, ValueHandle lbHandle, + ValueHandle ubHandle, ValueHandle stepHandle, + ArrayRef iter_args_handles, + ArrayRef iter_args_init_values) { mlir::edsc::LoopBuilder result; - auto forOp = - OperationHandle::createOp(lbHandle, ubHandle, stepHandle); + auto forOp = OperationHandle::createOp( + lbHandle, ubHandle, stepHandle, iter_args_init_values); *iv = ValueHandle(forOp.getInductionVar()); auto *body = loop::getForInductionVarOwner(iv->getValue()).getBody(); + for (size_t i = 0; i < iter_args_handles.size(); ++i) { + // Body argument 0 is the induction variable; iter args start at 1. 
+ *(iter_args_handles[i]) = ValueHandle(body->getArgument(i + 1)); + } result.enter(body, /*prev=*/1); return result; } diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -10,7 +10,7 @@ #include "mlir/Dialect/Affine/EDSC/Intrinsics.h" #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" -#include "mlir/Dialect/LoopOps/EDSC/Builders.h" +#include "mlir/Dialect/LoopOps/EDSC/Intrinsics.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/Dialect/Vector/EDSC/Intrinsics.h" #include "mlir/EDSC/Builders.h" @@ -1074,6 +1074,44 @@ f.erase(); } +TEST_FUNC(builder_loop_for_yield) { + auto indexType = IndexType::get(&globalContext()); + auto f32Type = FloatType::getF32(&globalContext()); + auto f = makeFunction("builder_loop_for_yield", {}, + {indexType, indexType, indexType, indexType}); + + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); + ValueHandle init0 = std_constant_float(llvm::APFloat(1.0f), f32Type); + ValueHandle init1 = std_constant_float(llvm::APFloat(2.0f), f32Type); + ValueHandle i(indexType), a(f.getArgument(0)), b(f.getArgument(1)), + c(f.getArgument(2)), d(f.getArgument(3)); + ValueHandle arg0(f32Type); + ValueHandle arg1(f32Type); + using namespace edsc::op; + auto results = + LoopNestBuilder(&i, a - b, c + d, a, {&arg0, &arg1}, {init0, init1})([&] { + auto sum = arg0 + arg1; + loop_yield(ArrayRef{arg1, sum}); + }); + ValueHandle(results[0]) + ValueHandle(results[1]); + + // clang-format off + // CHECK-LABEL: func @builder_loop_for_yield(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) { + // CHECK: [[init0:%.*]] = constant + // CHECK: [[init1:%.*]] = constant + // CHECK-DAG: [[r0:%[0-9]+]] = affine.apply affine_map<()[s0, s1] -> (s0 - s1)>()[%{{.*}}, %{{.*}}] + // CHECK-DAG: [[r1:%[0-9]+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%{{.*}}, %{{.*}}] + // CHECK-NEXT: 
[[res:%[0-9]+]]:2 = loop.for %{{.*}} = [[r0]] to [[r1]] step {{.*}} iter_args([[arg0:%.*]] = [[init0]], [[arg1:%.*]] = [[init1]]) -> (f32, f32) { + // CHECK: [[sum:%[0-9]+]] = addf [[arg0]], [[arg1]] : f32 + // CHECK: loop.yield [[arg1]], [[sum]] : f32, f32 + // CHECK: addf [[res]]#0, [[res]]#1 : f32 + + // clang-format on + f.print(llvm::outs()); + f.erase(); +} + int main() { RUN_TESTS(); return 0;