diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
--- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
+++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
@@ -94,37 +94,63 @@
   llvm_unreachable("Unsupported IterType");
 }
 
-/// A StructuredIndexed represents a captured value that can be indexed and
-/// passed to the `makeGenericLinalgOp`. It allows writing intuitive index
-/// expressions such as:
+/// A StructuredIndexed represents an indexable quantity that is either:
+/// 1. a captured value, which is suitable for buffer and tensor operands; or
+/// 2. a captured type, which is suitable for tensor return values.
+///
+/// A StructuredIndexed itself is indexed and passed to `makeGenericLinalgOp`.
+/// It enables an idiomatic syntax for index expressions such as:
 ///
 /// ```
-///    StructuredIndexed A(vA), B(vB), C(vC);
+///    StructuredIndexed A(buffer_or_tensor_value), B(buffer_or_tensor_value),
+///      C(buffer_value_or_tensor_type);
 ///    makeGenericLinalgOp({A({m, n}), B({k, n})}, {C({m, n})}, ... );
 /// ```
-struct StructuredIndexed {
-  StructuredIndexed(Value v) : value(v) {}
+struct StructuredIndexed : public ValueHandle {
+  StructuredIndexed(Type type) : ValueHandle(type) {}
+  StructuredIndexed(Value value) : ValueHandle(value) {}
+  StructuredIndexed(ValueHandle valueHandle) : ValueHandle(valueHandle) {}
   StructuredIndexed operator()(ArrayRef<AffineExpr> indexings) {
-    return StructuredIndexed(value, indexings);
+    return StructuredIndexed(*this, indexings);
   }
 
-  operator Value() const /* implicit */ { return value; }
   ArrayRef<AffineExpr> getExprs() { return exprs; }
 
 private:
+  StructuredIndexed(Type t, ArrayRef<AffineExpr> indexings)
+      : ValueHandle(t), exprs(indexings.begin(), indexings.end()) {
+    assert(t.isa<RankedTensorType>() && "RankedTensor expected");
+  }
   StructuredIndexed(Value v, ArrayRef<AffineExpr> indexings)
-      : value(v), exprs(indexings.begin(), indexings.end()) {
-    assert(v.getType().isa<MemRefType>() && "MemRefType expected");
+      : ValueHandle(v), exprs(indexings.begin(), indexings.end()) {
+    assert((v.getType().isa<MemRefType>() ||
+            v.getType().isa<RankedTensorType>()) &&
+           "MemRef or RankedTensor expected");
   }
-  StructuredIndexed(ValueHandle v, ArrayRef<AffineExpr> indexings)
-      : StructuredIndexed(v.getValue(), indexings) {}
+  StructuredIndexed(ValueHandle vh, ArrayRef<AffineExpr> indexings)
+      : ValueHandle(vh), exprs(indexings.begin(), indexings.end()) {}
 
-  Value value;
   SmallVector<AffineExpr, 4> exprs;
 };
 
 inline void defaultRegionBuilder(ArrayRef<BlockArgument> args) {}
 
+/// Build a `linalg.generic` op with the specified `inputs`, `outputs` and
+/// `region`.
+///
+/// `otherValues` and `otherAttributes` may be passed and will be appended as
+/// operands and attributes respectively.
+///
+/// Prerequisites:
+/// =============
+///
+/// 1. `inputs` may contain StructuredIndexed that capture either buffer or
+///    tensor values.
+/// 2. `outputs` may contain StructuredIndexed that capture either buffer
+///    values or tensor types. If both buffer values and tensor types are
+///    present, then all buffer values must appear before any tensor type.
+///    Without this restriction, output tensor results would need to be
+///    reordered, which would result in surprising behavior when combined
+///    with region definition.
 Operation *makeGenericLinalgOp(
     ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
     ArrayRef<StructuredIndexed> outputs,
@@ -189,7 +215,7 @@
                                 StructuredIndexed O);
 
 /// Build a linalg.pointwise with all `parallel` iterators and a region that
-/// computes `O = max(I!, I2)`. The client is responsible for specifying the
+/// computes `O = max(I1, I2)`. The client is responsible for specifying the
 /// proper indexings when creating the StructuredIndexed.
 Operation *linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2,
                                 StructuredIndexed O);
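A minimal usage sketch of the buffers-before-tensors prerequisite above (not part of the patch; it assumes an enclosing `ScopedContext`, bound dim exprs `m` and `n`, a memref-typed `Value bufferVal`, a `RankedTensorType tensorTy`, and a suitable `regionBuilder`; all names are illustrative):

```
// Capture a buffer operand by value and a tensor result by type.
StructuredIndexed Buf(bufferVal); // memref-typed Value
StructuredIndexed Ten(tensorTy);  // RankedTensorType, i.e. a result slot

// OK: the buffer output precedes the tensor output.
makeGenericLinalgOp({edsc::IterType::Parallel, edsc::IterType::Parallel},
                    /*inputs=*/{Buf({m, n})},
                    /*outputs=*/{Buf({m, n}), Ten({m, n})}, regionBuilder);

// Would trip the new assertion in makeGenericLinalgOp: a tensor output
// must not precede a buffer output.
// makeGenericLinalgOp({edsc::IterType::Parallel, edsc::IterType::Parallel},
//                     {Buf({m, n})}, {Ten({m, n}), Buf({m, n})},
//                     regionBuilder);
```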
diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h
--- a/mlir/include/mlir/EDSC/Builders.h
+++ b/mlir/include/mlir/EDSC/Builders.h
@@ -339,6 +339,7 @@
 
   /// Implicit conversion useful for automatic conversion to Container.
   operator Value() const { return getValue(); }
+  operator Type() const { return getType(); }
   operator bool() const { return hasValue(); }
 
   /// Generic mlir::Op create. This is the key to being extensible to the whole
diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
--- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -131,6 +131,10 @@
     ArrayRef<StructuredIndexed> outputs,
     function_ref<void(ArrayRef<BlockArgument>)> regionBuilder,
     ArrayRef<Value> otherValues, ArrayRef<Attribute> otherAttributes) {
+  for (unsigned i = 0, e = outputs.size(); i + 1 < e; ++i)
+    assert(!(outputs[i].getType().isa<RankedTensorType>() &&
+             outputs[i + 1].getType().isa<MemRefType>()) &&
+           "output tensors must be passed after output buffers");
   auto &builder = edsc::ScopedContext::getBuilder();
   auto *ctx = builder.getContext();
   unsigned nInputs = inputs.size();
@@ -154,7 +158,11 @@
   SmallVector<Value, 4> values;
   values.reserve(nViews);
   values.append(inputs.begin(), inputs.end());
-  values.append(outputs.begin(), outputs.end());
+  std::copy_if(outputs.begin(), outputs.end(), std::back_inserter(values),
+               [](StructuredIndexed s) { return s.hasValue(); });
+  SmallVector<Type, 4> types;
+  std::copy_if(outputs.begin(), outputs.end(), std::back_inserter(types),
+               [](StructuredIndexed s) { return !s.hasValue(); });
 
   auto iteratorStrTypes = functional::map(toString, iteratorTypes);
   // clang-format off
@@ -162,7 +170,7 @@
       edsc::ScopedContext::getBuilder()
           .create<linalg::GenericOp>(
               edsc::ScopedContext::getLocation(),
-              ArrayRef<Type>{}, // TODO(ntv): support tensors
+              types,
              values,
              IntegerAttr::get(IntegerType::get(64, ctx), nInputs),
              IntegerAttr::get(IntegerType::get(64, ctx), nOutputs),
@@ -210,6 +218,14 @@
                                         StructuredIndexed O) {
   SmallVector<IterType, 4> iterTypes(O.getExprs().size(),
                                      edsc::IterType::Parallel);
+  if (O.getType().isa<RankedTensorType>()) {
+    auto fun = [&unaryOp](ArrayRef<BlockArgument> args) {
+      assert(args.size() == 1 && "expected 1 block argument");
+      ValueHandle a(args[0]);
+      linalg_yield(unaryOp(a));
+    };
+    return makeGenericLinalgOp(iterTypes, {I}, {O}, fun);
+  }
   auto fun = [&unaryOp](ArrayRef<BlockArgument> args) {
     assert(args.size() == 2 && "expected 2 block arguments");
     ValueHandle a(args[0]);
@@ -220,7 +236,6 @@
 
 Operation *mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I,
                                                   StructuredIndexed O) {
-  ;
   using edsc::intrinsics::tanh;
   UnaryPointwiseOpBuilder unOp([](ValueHandle a) -> Value { return tanh(a); });
   return linalg_pointwise(unOp, I, O);
@@ -233,6 +248,14 @@
                                  StructuredIndexed O) {
   SmallVector<IterType, 4> iterTypes(O.getExprs().size(),
                                      edsc::IterType::Parallel);
+  if (O.getType().isa<RankedTensorType>()) {
+    auto fun = [&binaryOp](ArrayRef<BlockArgument> args) {
+      assert(args.size() == 2 && "expected 2 block arguments");
+      ValueHandle a(args[0]), b(args[1]);
+      linalg_yield(binaryOp(a, b));
+    };
+    return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun);
+  }
   auto fun = [&binaryOp](ArrayRef<BlockArgument> args) {
     assert(args.size() == 3 && "expected 3 block arguments");
     ValueHandle a(args[0]), b(args[1]);
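A side note on the two `std::copy_if` calls above: together they act as a stable partition of `outputs` into Value-backed operands and type-only tensor results, preserving relative order within each class; combined with the buffers-before-tensors assertion, this keeps result order aligned with the region signature. A self-contained toy model of that split (illustrative names only, not MLIR code):

```
#include <algorithm>
#include <iostream>
#include <iterator>
#include <vector>

// Stand-in for StructuredIndexed: either carries a value (an operand) or
// only a type (a tensor result slot).
struct Out { const char *name; bool hasValue; };

int main() {
  std::vector<Out> outputs = {{"bufA", true}, {"tenB", false}, {"tenC", false}};
  std::vector<Out> values, types;
  // Same idiom as the patch: two copy_if passes keep relative order.
  std::copy_if(outputs.begin(), outputs.end(), std::back_inserter(values),
               [](Out o) { return o.hasValue; });
  std::copy_if(outputs.begin(), outputs.end(), std::back_inserter(types),
               [](Out o) { return !o.hasValue; });
  for (auto &o : values) std::cout << "operand:     " << o.name << "\n";
  for (auto &o : types)  std::cout << "result type: " << o.name << "\n";
}
```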
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -871,6 +871,49 @@
   f.erase();
 }
 
+// clang-format off
+// CHECK-LABEL: func @linalg_pointwise_mixed_tensors
+// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
+// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK: addf
+// CHECK: }: tensor<?x?xf32>, memref<?x?xf32> -> tensor<?x?xf32>
+// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
+// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK: cmpf "ogt"
+// CHECK: select
+// CHECK: }: tensor<?x?xf32>, memref<?x?xf32> -> tensor<?x?xf32>
+// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
+// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK: tanh
+// CHECK: }: tensor<?x?xf32> -> tensor<?x?xf32>
+// clang-format on
+TEST_FUNC(linalg_pointwise_mixed_tensors_test) {
+  using namespace edsc;
+  using namespace edsc::ops;
+
+  auto f32Type = FloatType::getF32(&globalContext());
+  auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0);
+  auto tensorType = RankedTensorType::get({-1, -1}, f32Type);
+  auto f = makeFunction("linalg_pointwise_mixed_tensors", {},
+                        {tensorType, memrefType});
+
+  OpBuilder builder(f.getBody());
+  ScopedContext scope(builder, f.getLoc());
+  ValueHandle A(f.getArgument(0)), B(f.getArgument(1));
+  AffineExpr i, j;
+  bindDims(&globalContext(), i, j);
+  StructuredIndexed SA(A), SB(B), SC(tensorType);
+  linalg_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j}));
+  linalg_pointwise_max(SA({i, j}), SB({i, j}), SC({i, j}));
+  linalg_pointwise_tanh(SA({i, j}), SC({i, j}));
+
+  f.print(llvm::outs());
+  f.erase();
+}
+
 // clang-format off
 // CHECK-LABEL: func @linalg_matmul
 // CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
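As a closing usage note (a sketch, not part of the patch): because tensor outputs now materialize as op results, a pointwise op's result can be re-captured and fed to the next op. This assumes the same setup as the test above; `add` and `SRes` are hypothetical names:

```
// Chain two tensor-result pointwise ops: the add produces a tensor result,
// which is wrapped in a new StructuredIndexed and consumed by tanh.
Operation *add = linalg_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j}));
StructuredIndexed SRes(add->getResult(0)); // capture the produced tensor Value
linalg_pointwise_tanh(SRes({i, j}), StructuredIndexed(tensorType)({i, j}));
```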