diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -547,7 +547,7 @@ tensor arguments. The semantics is that the `linalg.generic` op produces (i.e. allocates and fills) its return values. Tensor values must be legalized by a buffer allocation pass before most - transformations can be applied. In particular, transformations that create + transformations can be applied. In particular, transformations which create control flow around linalg.generic operations are not expected to mix with tensors because SSA values do not escape naturally. Still, transformations and rewrites that take advantage of tensor SSA values are expected to be diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h @@ -13,6 +13,7 @@ #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/StandardTypes.h" +#include "mlir/Support/Functional.h" #include "mlir/Support/LLVM.h" namespace mlir { @@ -114,19 +115,15 @@ } /// Query the subset of input operands that are of ranked tensor type. SmallVector getInputTensorTypes() { - SmallVector res; - for (Type type : getInputs().getTypes()) - if (auto t = type.template dyn_cast()) - res.push_back(t); - return res; + return functional::map_if_non_null( + [](Type t) { return t.dyn_cast(); }, + getInputs().getTypes()); } /// Query the subset of output operands that are of ranked tensor type.
SmallVector getOutputTensorTypes() { - SmallVector res; - for (Type type : getOutputs().getTypes()) - if (auto t = type.template dyn_cast()) - res.push_back(t); - return res; + return functional::map_if_non_null( + [](Type t) { return t.dyn_cast(); }, + getOutputs().getTypes()); } /// Return the range over outputs. Operation::operand_range getOutputs() { diff --git a/mlir/include/mlir/Support/Functional.h b/mlir/include/mlir/Support/Functional.h --- a/mlir/include/mlir/Support/Functional.h +++ b/mlir/include/mlir/Support/Functional.h @@ -44,6 +44,51 @@ return map(fun, std::begin(input), std::end(input)); } +/// Map with iterators, filtered: applies `fun` to each element in [begin, end) +/// and keeps only the mapped values for which `predicate` returns true. +template +auto map_if(Fn fun, IterType begin, IterType end, PredFn predicate) + -> SmallVector::type, 8> { + using R = typename std::result_of::type; + SmallVector res; + // auto i works with both pointer types and value types with an operator*. + // auto *i only works for pointer types. + for (auto i = begin; i != end; ++i) { + auto v = fun(*i); + if (predicate(v)) + res.push_back(v); + } + return res; +} + +/// Map with templated container, filtered: forwards to the iterator overload, +/// keeping only the mapped values for which `predicate` returns true. +template +auto map_if(Fn fun, ContainerType input, PredFn predicate) + -> decltype(map_if(fun, std::begin(input), std::end(input), predicate)) { + return map_if(fun, std::begin(input), std::end(input), predicate); +} + +/// Map with iterators, dropping nulls: keeps only the mapped values that +/// compare unequal to a value-initialized instance of the result type. +template +auto map_if_non_null(Fn fun, IterType begin, IterType end) + -> SmallVector::type, 8> { + return map_if(fun, begin, end, + [](typename std::result_of::type v) { + return v != + typename std::result_of::type(); + }); +} + +/// Map with templated container, dropping nulls: forwards to the iterator +/// overload of map_if_non_null over std::begin(input) and std::end(input).
+template +auto map_if_non_null(Fn fun, ContainerType input) + -> decltype(map_if_non_null(fun, std::begin(input), std::end(input))) { + return map_if_non_null(fun, std::begin(input), std::end(input)); +} + /// Zip map with 2 templated container, iterates to the min of the sizes of /// the 2 containers. /// TODO(ntv): make variadic when needed.