// Patch summary: moves the typed iterator enum (edsc::IterType -> mlir::IteratorType, into StructuredOpsUtils.h) and edsc::StructuredIndexed (Linalg/EDSC/Builders.h -> EDSC/Builders.h), then updates the Linalg EDSC call sites in Builders.cpp.
// NOTE(review): this copy has its newlines collapsed and angle-bracketed template arguments stripped (e.g. `ArrayRef indexings`, `t.isa()`), so it will not apply or compile as-is; re-fetch the upstream patch before using it.
diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h --- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h @@ -87,57 +87,6 @@ std::unique_ptr builder; }; -enum class IterType { Parallel, Reduction }; - -inline StringRef toString(IterType t) { - switch (t) { - case IterType::Parallel: - return getParallelIteratorTypeName(); - case IterType::Reduction: - return getReductionIteratorTypeName(); - } - llvm_unreachable("Unsupported IterType"); -} - -/// A StructuredIndexed represents an indexable quantity that is either: -/// 1. a captured value, which is suitable for buffer and tensor operands, or; -/// 2. a captured type, which is suitable for tensor return values. -/// -/// A StructuredIndexed itself is indexed and passed to `makeGenericLinalgOp`. -/// It enable an idiomatic syntax for index expressions such as: -/// -/// ``` -/// StructuredIndexed A(buffer_or_tensor_value), B(buffer_or_tensor_value), -/// C(buffer_value_or_tensor_type); -/// makeGenericLinalgOp({A({m, n}), B({k, n})}, {C({m, n})}, ... 
); -/// ``` -struct StructuredIndexed : public ValueHandle { - StructuredIndexed(Type type) : ValueHandle(type) {} - StructuredIndexed(Value value) : ValueHandle(value) {} - StructuredIndexed(ValueHandle valueHandle) : ValueHandle(valueHandle) {} - StructuredIndexed operator()(ArrayRef indexings) { - return StructuredIndexed(*this, indexings); - } - - ArrayRef getExprs() { return exprs; } - -private: - StructuredIndexed(Type t, ArrayRef indexings) - : ValueHandle(t), exprs(indexings.begin(), indexings.end()) { - assert(t.isa() && "RankedTensor expected"); - } - StructuredIndexed(Value v, ArrayRef indexings) - : ValueHandle(v), exprs(indexings.begin(), indexings.end()) { - assert((v.getType().isa() || - v.getType().isa()) && - "MemRef or RankedTensor expected"); - } - StructuredIndexed(ValueHandle vh, ArrayRef indexings) - : ValueHandle(vh), exprs(indexings.begin(), indexings.end()) {} - - SmallVector exprs; -}; - inline void defaultRegionBuilder(ArrayRef args) {} /// Build a `linalg.generic` op with the specified `inputs`, `outputs` and @@ -157,7 +106,7 @@ /// restriction output tensor results would need to be reordered, which would /// result in surprising behavior when combined with region definition. Operation *makeGenericLinalgOp( - ArrayRef iteratorTypes, ArrayRef inputs, + ArrayRef iteratorTypes, ArrayRef inputs, ArrayRef outputs, function_ref)> regionBuilder = defaultRegionBuilder, diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -84,6 +84,19 @@ return res; } +/// Typed representation of the iterator type strings returned by getParallelIteratorTypeName() / getReductionIteratorTypeName(). 
+enum class IteratorType { Parallel, Reduction }; + +inline StringRef toString(IteratorType t) { + switch (t) { + case IteratorType::Parallel: + return getParallelIteratorTypeName(); + case IteratorType::Reduction: + return getReductionIteratorTypeName(); + } + llvm_unreachable("Unsupported IteratorType"); +} + } // end namespace mlir #endif // MLIR_UTILS_STRUCTUREDOPSUTILS_H diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -17,6 +17,7 @@ #include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/AffineExpr.h" #include "mlir/IR/Builders.h" #include "mlir/Transforms/FoldUtils.h" @@ -493,6 +494,46 @@ mlir::Block *block; }; +/// A StructuredIndexed represents an indexable quantity that is either: +/// 1. a captured value, which is suitable for buffer and tensor operands, or; +/// 2. a captured type, which is suitable for tensor return values. +/// +/// A StructuredIndexed is indexed via operator() (keeping its value if it has one, otherwise its type) and passed to `makeGenericLinalgOp`. +/// It enables an idiomatic syntax for index expressions such as: +/// +/// ``` +/// StructuredIndexed A(buffer_or_tensor_value), B(buffer_or_tensor_value), +/// C(buffer_value_or_tensor_type); +/// makeGenericLinalgOp({A({m, n}), B({k, n})}, {C({m, n})}, ... ); +/// ``` +struct StructuredIndexed : public ValueHandle { + StructuredIndexed(Type type) : ValueHandle(type) {} + StructuredIndexed(Value value) : ValueHandle(value) {} + StructuredIndexed(ValueHandle valueHandle) : ValueHandle(valueHandle) {} + StructuredIndexed operator()(ArrayRef indexings) { + return this->hasValue() ? 
StructuredIndexed(this->getValue(), indexings) + : StructuredIndexed(this->getType(), indexings); + } + + StructuredIndexed(Type t, ArrayRef indexings) + : ValueHandle(t), exprs(indexings.begin(), indexings.end()) { + assert(t.isa() && "RankedTensor expected"); + } + StructuredIndexed(Value v, ArrayRef indexings) + : ValueHandle(v), exprs(indexings.begin(), indexings.end()) { + assert((v.getType().isa() || + v.getType().isa()) && + "MemRef or RankedTensor expected"); + } + StructuredIndexed(ValueHandle vh, ArrayRef indexings) + : ValueHandle(vh), exprs(indexings.begin(), indexings.end()) {} + + ArrayRef getExprs() { return exprs; } + +private: + SmallVector exprs; +}; + template OperationHandle OperationHandle::create(Args... args) { return OperationHandle(ScopedContext::getBuilder() diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Linalg/EDSC/Builders.h" #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/EDSC/Builders.h" #include "mlir/EDSC/Intrinsics.h" #include "mlir/IR/AffineExpr.h" @@ -144,7 +145,7 @@ } Operation *mlir::edsc::makeGenericLinalgOp( - ArrayRef iteratorTypes, ArrayRef inputs, + ArrayRef iteratorTypes, ArrayRef inputs, ArrayRef outputs, function_ref)> regionBuilder, ArrayRef otherValues, ArrayRef otherAttributes) { @@ -240,8 +241,8 @@ Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp, StructuredIndexed I, StructuredIndexed O) { - SmallVector iterTypes(O.getExprs().size(), - edsc::IterType::Parallel); + SmallVector iterTypes(O.getExprs().size(), + IteratorType::Parallel); if (O.getType().isa()) { auto fun = [&unaryOp](ArrayRef args) { assert(args.size() == 1 && "expected 1 block arguments"); @@ -270,8 +271,8 @@ StructuredIndexed I1, 
StructuredIndexed I2, StructuredIndexed O) { - SmallVector iterTypes(O.getExprs().size(), - edsc::IterType::Parallel); + SmallVector iterTypes(O.getExprs().size(), + IteratorType::Parallel); if (O.getType().isa()) { auto fun = [&binaryOp](ArrayRef args) { assert(args.size() == 2 && "expected 2 block arguments"); @@ -315,7 +316,7 @@ bindDims(ScopedContext::getContext(), m, n, k); StructuredIndexed A(vA), B(vB), C(vC); return makeGenericLinalgOp( - {IterType::Parallel, IterType::Parallel, IterType::Reduction}, + {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction}, {A({m, k}), B({k, n})}, {C({m, n})}, macRegionBuilder); @@ -329,7 +330,7 @@ bindDims(ScopedContext::getContext(), m, n, k); StructuredIndexed A(vA), B(vB), C(tC); return makeGenericLinalgOp( - {IterType::Parallel, IterType::Parallel, IterType::Reduction}, + {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction}, {A({m, k}), B({k, n})}, {C({m, n})}, mulRegionBuilder); @@ -343,7 +344,7 @@ bindDims(ScopedContext::getContext(), m, n, k); StructuredIndexed A(vA), B(vB), C(vC), D(tD); return makeGenericLinalgOp( - {IterType::Parallel, IterType::Parallel, IterType::Reduction}, + {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction}, {A({m, k}), B({k, n}), C({m, n})}, {D({m, n})}, macRegionBuilder); @@ -360,8 +361,8 @@ assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm"); // Some short names. - auto par = IterType::Parallel; - auto red = IterType::Reduction; + auto par = IteratorType::Parallel; + auto red = IteratorType::Reduction; auto s = strides; auto d = dilations; @@ -393,8 +394,8 @@ assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm"); // Some short names. - auto par = IterType::Parallel; - auto red = IterType::Reduction; + auto par = IteratorType::Parallel; + auto red = IteratorType::Reduction; auto s = strides; auto d = dilations;