diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1766,6 +1766,8 @@
 // Further described in docs/Rationale/RationaleTOSADialect.md .
 //===----------------------------------------------------------------------===//
 def Tosa_IfOp : Tosa_Op<"cond_if", [
+      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                                ["inferReturnTypeComponents"]>,
       SingleBlockImplicitTerminator<"YieldOp">,
       RecursiveSideEffects]> {
   let summary = "Conditional if operator";
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h
@@ -0,0 +1,155 @@
+//===-- ShapeUtils.h - TOSA shape support declarations ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Class declarations for shape utilities meant to assist shape propagation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef DIALECT_TOSA_UTILS_SHAPE_UTILS_H
+#define DIALECT_TOSA_UTILS_SHAPE_UTILS_H
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+
+#include <cassert>
+#include <tuple>
+#include <vector>
+
+namespace mlir {
+namespace tosa {
+// Statically known information for a particular Value.
+//
+// This struct currently tracks only information relevant for tensor/array-like
+// shaped types. It is fine to associate a `ValueKnowledge` with a non-shaped
+// type as long as it is in the default "no knowledge" state returned by
+// `getPessimisticValueState`. The important invariant is that we cannot
+// claim to know something about a value which is false.
+//
+// This class could also be called "dataflow facts", "lattice value", etc.
+struct ValueKnowledge {
+  ValueKnowledge() = delete;
+  ValueKnowledge(bool hasSizes, std::vector<int64_t> sizes, Type dtype)
+      : hasSizes(hasSizes), sizes(sizes), dtype(dtype) {
+    assert(sizes.size() == 0 || hasSizes);
+  }
+
+  // Get the static knowledge intrinsic to `type`.
+  static ValueKnowledge getKnowledgeFromType(Type type) {
+    ValueKnowledge result = getPessimisticValueState(type.getContext());
+    if (auto shapedType = type.dyn_cast<ShapedType>()) {
+      if (shapedType.hasRank()) {
+        result.hasSizes = true;
+        result.sizes = shapedType.getShape();
+      }
+      result.dtype = shapedType.getElementType();
+    }
+    return result;
+  }
+
+  // Return a pessimistic/conservative value state without assuming any
+  // knowledge about the IR.
+  static ValueKnowledge getPessimisticValueState(MLIRContext *context) {
+    return ValueKnowledge(false, {}, Type());
+  }
+
+  Type getType() const {
+    if (hasSizes) {
+      return RankedTensorType::get(llvm::makeArrayRef(sizes), dtype);
+    }
+    return UnrankedTensorType::get(dtype);
+  }
+
+  bool operator==(const ValueKnowledge &rhs) const {
+    return std::make_tuple(hasSizes, sizes, dtype) ==
+           std::make_tuple(rhs.hasSizes, rhs.sizes, rhs.dtype);
+  }
+
+  // Given two pieces of static knowledge, calculate conservatively the
+  // information we can be sure about.
+  static ValueKnowledge join(const ValueKnowledge &lhs,
+                             const ValueKnowledge &rhs) {
+    // Mental model: All conditions are checking how to change from the safe
+    // "no knowledge" default-initialized state to a state with more knowledge
+    // consistent with lhs and rhs.
+    ValueKnowledge result = getPessimisticValueState(nullptr);
+    result.dtype = lhs.dtype == rhs.dtype ? lhs.dtype : Type();
+
+    if (!lhs.hasSizes && !rhs.hasSizes) {
+      return result;
+    }
+
+    if (!rhs.hasSizes) {
+      result.hasSizes = true;
+      result.sizes = lhs.sizes;
+      return result;
+    }
+
+    if (!lhs.hasSizes) {
+      result.hasSizes = true;
+      result.sizes = rhs.sizes;
+      return result;
+    }
+
+    if (lhs.sizes.size() != rhs.sizes.size()) {
+      return result;
+    }
+
+    result.hasSizes = true;
+    result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamicSize);
+    for (int i = 0, e = result.sizes.size(); i != e; i++) {
+      int64_t lhsSize = lhs.sizes[i];
+      int64_t rhsSize = rhs.sizes[i];
+      int64_t &resultSize = result.sizes[i];
+      if (lhsSize == ShapedType::kDynamicSize) {
+        resultSize = rhsSize;
+      } else if (rhsSize == ShapedType::kDynamicSize) {
+        resultSize = lhsSize;
+      } else if (lhsSize == rhsSize) {
+        resultSize = lhsSize;
+      }
+    }
+
+    return result;
+  }
+
+  // Given two types, generate a new ValueKnowledge that expands to cover both
+  // cases. E.g. if the rank of the LHS and RHS differ, the resulting tensor
+  // has unknown rank.
+  static ValueKnowledge expand(const ValueKnowledge &lhs,
+                               const ValueKnowledge &rhs) {
+    ValueKnowledge result = getPessimisticValueState(nullptr);
+    result.dtype = lhs.dtype == rhs.dtype ? lhs.dtype : Type();
+
+    if (!lhs.hasSizes || !rhs.hasSizes) {
+      result.hasSizes = false;
+      return result;
+    }
+
+    if (lhs.sizes.size() != rhs.sizes.size()) {
+      result.hasSizes = false;
+      return result;
+    }
+
+    result.hasSizes = true;
+    result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamicSize);
+    for (int i = 0, e = lhs.sizes.size(); i < e; i++) {
+      if (lhs.sizes[i] == rhs.sizes[i]) {
+        result.sizes[i] = lhs.sizes[i];
+      }
+    }
+
+    return result;
+  }
+
+  // Whether the Value is known to have a list of sizes.
+  bool hasSizes;
+  // If `hasSizes`, the sizes along each rank. Unknown sizes are represented as
+  // `ShapedType::kDynamicSize`.
+  std::vector<int64_t> sizes;
+  // The dtype of a tensor.
+  // This is equal to nullptr if we don't know that it is a specific concrete
+  // type.
+  Type dtype;
+};
+} // namespace tosa
+} // namespace mlir
+
+#endif // DIALECT_TOSA_UTILS_SHAPE_UTILS_H
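
Annotation (not part of the patch): `join` and `expand` are dual operations on the shape lattice above. `join` combines two facts about the *same* value, so per dimension it keeps the more precise of the two sizes; `expand` covers two *alternative* values, so it keeps only the sizes both sides agree on. A minimal standalone sketch of the same per-dimension logic, assuming `-1` stands in for `ShapedType::kDynamicSize` and `std::nullopt` models an unranked shape; dtype handling is omitted and all names here are illustrative:

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

using Shape = std::optional<std::vector<int64_t>>; // nullopt == unranked
constexpr int64_t kDyn = -1; // stand-in for ShapedType::kDynamicSize

// Same value seen twice: keep the more precise size per dimension.
Shape joinShapes(const Shape &lhs, const Shape &rhs) {
  if (!lhs)
    return rhs;
  if (!rhs)
    return lhs;
  if (lhs->size() != rhs->size())
    return std::nullopt; // conflicting ranks: no size knowledge
  std::vector<int64_t> out(lhs->size(), kDyn);
  for (size_t i = 0; i < out.size(); ++i) {
    int64_t l = (*lhs)[i], r = (*rhs)[i];
    if (l == kDyn)
      out[i] = r;
    else if (r == kDyn || l == r)
      out[i] = l;
    // else: conflicting static sizes stay dynamic.
  }
  return out;
}

// Two alternative values: keep only the sizes both sides agree on.
Shape expandShapes(const Shape &lhs, const Shape &rhs) {
  if (!lhs || !rhs || lhs->size() != rhs->size())
    return std::nullopt; // unknown or differing rank widens to unranked
  std::vector<int64_t> out(lhs->size(), kDyn);
  for (size_t i = 0; i < out.size(); ++i)
    if ((*lhs)[i] == (*rhs)[i])
      out[i] = (*lhs)[i];
  return out;
}

int main() {
  auto S = [](std::vector<int64_t> v) { return Shape(std::move(v)); };
  // join refines: tensor<?x3xf32> + tensor<2x?xf32> -> tensor<2x3xf32>
  assert(joinShapes(S({kDyn, 3}), S({2, kDyn})) == S({2, 3}));
  // expand widens: tensor<2xf32> vs tensor<3xf32> -> tensor<?xf32>
  assert(expandShapes(S({2}), S({3})) == S({kDyn}));
  // expand across differing ranks gives up on rank entirely.
  assert(expandShapes(S({2}), S({2, 2})) == std::nullopt);
  return 0;
}
```

Note that conflicting static sizes under `join` fall back to dynamic rather than failing, mirroring the patch's behavior.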
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Transforms/FoldUtils.h"
@@ -901,6 +902,53 @@
   return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
 }
 
+LogicalResult IfOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  llvm::SmallVector<tosa::YieldOp> yieldOps;
+  for (Region *region : regions) {
+    for (auto &block : *region) {
+      if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator())) {
+        yieldOps.push_back(returnOp);
+      }
+    }
+  }
+
+  if (yieldOps.empty())
+    return failure();
+
+  // Get the initial type information for the yield op.
+  llvm::SmallVector<ValueKnowledge> resultKnowledge;
+  resultKnowledge.reserve(yieldOps.front().getNumOperands());
+  for (auto operand : yieldOps.front().getOperands()) {
+    resultKnowledge.push_back(
+        ValueKnowledge::getKnowledgeFromType(operand.getType()));
+  }
+
+  for (auto yieldOp : yieldOps) {
+    if (resultKnowledge.size() != yieldOp.getNumOperands())
+      return failure();
+
+    for (auto it : llvm::enumerate(yieldOp.getOperands())) {
+      int32_t index = it.index();
+      resultKnowledge[index] = ValueKnowledge::expand(
+          resultKnowledge[index],
+          ValueKnowledge::getKnowledgeFromType(it.value().getType()));
+    }
+  }
+
+  for (auto result : resultKnowledge) {
+    if (result.hasSizes) {
+      inferredReturnShapes.push_back(ShapedTypeComponents(result.sizes));
+    } else {
+      inferredReturnShapes.push_back(ShapedTypeComponents());
+    }
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Definitions.
 //===----------------------------------------------------------------------===//
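
Annotation (not part of the patch): with this hook in place, clients no longer need to special-case `tosa.cond_if`; they can query it through `InferShapedTypeOpInterface` exactly as the pass below does. A sketch of such a call site, assuming the MLIR APIs as of this revision; `queryIfShapes` is an illustrative name, not part of the patch:

```cpp
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

using namespace mlir;

// Ask a tosa.cond_if for its inferred result components through the
// interface, the same way TosaInferShapes drives every TOSA op.
static LogicalResult
queryIfShapes(tosa::IfOp ifOp, SmallVectorImpl<ShapedTypeComponents> &out) {
  auto iface = cast<InferShapedTypeOpInterface>(ifOp.getOperation());
  return iface.inferReturnTypeComponents(
      ifOp.getContext(), ifOp.getLoc(), ifOp->getOperands(),
      ifOp->getAttrDictionary(), ifOp->getRegions(), out);
}
```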
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -30,137 +31,55 @@
 
 namespace {
 
-// -----------------------------------------------------------------------------
-// Analysis.
-// -----------------------------------------------------------------------------
-
-static Type joinElementTypes(Type lhs, Type rhs) {
-  return lhs == rhs ? lhs : Type();
-}
-
-namespace {
-// Statically known information for a particular Value.
-//
-// This struct currently tracks only information relevant for tensor/array-like
-// shaped types. It is fine to associate a `ValueKnowledge` with a non-shaped
-// type as long as it is in the default "no knowledge" state returned by
-// `getPessimisticValueState`. The important invariant is that we cannot
-// claim to know something about a value which is false.
-//
-// This class could also be called "dataflow facts", "lattice value", etc.
-struct ValueKnowledge {
-  ValueKnowledge() = delete;
-  ValueKnowledge(bool hasSizes, std::vector<int64_t> sizes, Type dtype)
-      : hasSizes(hasSizes), sizes(sizes), dtype(dtype) {
-    assert(sizes.size() == 0 || hasSizes);
-  }
-
-  // Get the static knowledge intrinsic to `type`.
-  static ValueKnowledge getKnowledgeFromType(Type type) {
-    ValueKnowledge result = getPessimisticValueState(type.getContext());
-    if (auto shapedType = type.dyn_cast<ShapedType>()) {
-      if (shapedType.hasRank()) {
-        result.hasSizes = true;
-        result.sizes = shapedType.getShape();
-      }
-      result.dtype = shapedType.getElementType();
+void PropagateRegion(Region &region);
+
+void PropagateIfCond(Operation &op) {
+  tosa::IfOp ifOp = dyn_cast<tosa::IfOp>(op);
+  if (!ifOp)
+    return;
+
+  for (auto &region : op.getRegions()) {
+    Block &frontBlock = region.front();
+    if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
+      return;
+
+    for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
+      ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
+          ifOp.getOperand(i + 1).getType());
+      ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
+          frontBlock.getArgument(i).getType());
+      ValueKnowledge joinedKnowledge =
+          ValueKnowledge::join(operandKnowledge, blockKnowledge);
+      frontBlock.getArgument(i).setType(joinedKnowledge.getType());
     }
-    return result;
-  }
 
-  // Return a pessimistic/conservative value state without assuming any knowlege
-  // about the IR.
-  static ValueKnowledge getPessimisticValueState(MLIRContext *context) {
-    return ValueKnowledge(false, {}, Type());
-  }
-
-  Type getType() const {
-    if (hasSizes) {
-      return RankedTensorType::get(llvm::makeArrayRef(sizes), dtype);
-    }
-    return UnrankedTensorType::get(dtype);
+    PropagateRegion(region);
   }
 
-  bool operator==(const ValueKnowledge &rhs) const {
-    return std::make_tuple(hasSizes, sizes, dtype) ==
-           std::make_tuple(rhs.hasSizes, rhs.sizes, rhs.dtype);
-  }
-
-  // Given two pieces of static knowledge, calculate conservatively the
-  // information we can be sure about.
-  static ValueKnowledge join(const ValueKnowledge &lhs,
-                             const ValueKnowledge &rhs) {
-    // Mental model: All conditions are checking how to change from the safe "no
-    // knowledge" default-initialized state to a state with more knowledge
-    // consistent with lhs and rhs.
-    ValueKnowledge result = getPessimisticValueState(nullptr);
-
-    if (lhs.hasSizes && !rhs.hasSizes) {
-      result.hasSizes = true;
-      result.sizes = lhs.sizes;
-    } else if (!lhs.hasSizes && rhs.hasSizes) {
-      result.hasSizes = true;
-      result.sizes = rhs.sizes;
-    } else if (lhs.hasSizes && rhs.hasSizes &&
-               lhs.sizes.size() == rhs.sizes.size()) {
-      result.hasSizes = true;
-      result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamicSize);
-      for (int i = 0, e = result.sizes.size(); i != e; i++) {
-        int64_t lhsSize = lhs.sizes[i];
-        int64_t rhsSize = rhs.sizes[i];
-        int64_t &resultSize = result.sizes[i];
-        if (lhsSize == ShapedType::kDynamicSize) {
-          resultSize = rhsSize;
-        } else if (rhsSize == ShapedType::kDynamicSize) {
-          resultSize = lhsSize;
-        } else if (lhsSize == rhsSize) {
-          resultSize = lhsSize;
-        }
-      }
-    }
-
-    result.dtype = joinElementTypes(lhs.dtype, rhs.dtype);
-    return result;
-  }
-
-  // Whether the Value is known to have a list of sizes.
-  bool hasSizes;
-  // If `hasSizes`, the sizes along each rank. Unknown sizes are represented as
-  // `ShapedType::kDynamicSize`.
-  std::vector<int64_t> sizes;
-  // The dtype of a tensor.
-  // This is equal to nullptr if we don't know that it is a specific concrete
-  // type.
-  Type dtype;
-};
-
-} // namespace
+  return;
+}
 
-/// Pass that enables broadcast by making all input arrays have the same
-/// number of dimensions. Insert RESHAPE operations to lower rank operand
-struct TosaInferShapes : public TosaInferShapesBase<TosaInferShapes> {
-public:
-  void runOnFunction() override {
-    FuncOp func = getOperation();
+void PropagateRegion(Region &region) {
+  for (auto &block : region) {
+    for (Operation &op : block) {
+      if (op.getDialect()->getNamespace() !=
+          tosa::TosaDialect::getDialectNamespace())
+        continue;
 
-    IRRewriter rewriter(func.getContext());
+      PropagateIfCond(op);
 
-    func.walk([&](Operation *op) {
-      if (op->getDialect()->getNamespace() !=
-          tosa::TosaDialect::getDialectNamespace())
-        return;
       InferShapedTypeOpInterface shapeInterface =
           dyn_cast<InferShapedTypeOpInterface>(op);
       if (!shapeInterface)
-        return;
+        continue;
 
       SmallVector<ShapedTypeComponents> returnedShapes;
       if (shapeInterface
              .inferReturnTypeComponents(
-                 op->getContext(), op->getLoc(), op->getOperands(),
-                 op->getAttrDictionary(), op->getRegions(), returnedShapes)
+                 op.getContext(), op.getLoc(), op.getOperands(),
+                 op.getAttrDictionary(), op.getRegions(), returnedShapes)
              .succeeded()) {
-        for (auto it : llvm::zip(op->getResults(), returnedShapes)) {
+        for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
           Value result = std::get<0>(it);
           ShapedTypeComponents predictedShape = std::get<1>(it);
@@ -184,7 +103,7 @@
 
           // Compute the knowledge based on the inferred type.
           auto inferredKnowledge =
-              ValueKnowledge::getPessimisticValueState(op->getContext());
+              ValueKnowledge::getPessimisticValueState(op.getContext());
           inferredKnowledge.dtype = resultTy.cast<ShapedType>().getElementType();
           inferredKnowledge.hasSizes = predictedShape.hasRank();
@@ -203,7 +122,20 @@
           result.setType(newKnowledge.getType());
         }
       }
-    });
+    }
+  }
+}
+
+/// Pass that propagates shapes across TOSA operations, including into the
+/// regions of control-flow operations such as tosa.cond_if.
+struct TosaInferShapes : public TosaInferShapesBase<TosaInferShapes> {
+public:
+  void runOnFunction() override {
+    FuncOp func = getOperation();
+
+    IRRewriter rewriter(func.getContext());
+
+    PropagateRegion(func.body());
 
     // Insert UnrealizedConversionCasts to guarantee ReturnOp agrees with
     // the FuncOp type.
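
Annotation (not part of the patch): the factory `tosa::createTosaInferShapesPass()` declared in `mlir/Dialect/Tosa/Transforms/Passes.h` exposes this pass, and the same pass is registered with `mlir-opt` as `--tosa-infer-shapes`, which is how the test file below drives it. A sketch of running it programmatically; `inferAllTosaShapes` is an illustrative name, and the nesting assumes the pass still runs on `FuncOp`:

```cpp
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

using namespace mlir;

// Run TOSA shape inference over every function in `module`.
LogicalResult inferAllTosaShapes(ModuleOp module) {
  PassManager pm(module.getContext());
  // TosaInferShapes is a function pass, so nest it under the module's FuncOps.
  pm.addNestedPass<FuncOp>(tosa::createTosaInferShapesPass());
  return pm.run(module);
}
```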
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -708,3 +708,65 @@
   %1 = "tosa.max_pool2d"(%arg0) {kernel = [4, 3], pad = [0, 0, 0, 0], stride = [2, 3]} : (tensor<3x11x12x7xf32>) -> tensor<?x?x?x?xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @if_test_simple
+func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
+  // CHECK: (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb1(%arg3 : tensor<f32>, %arg4 : tensor<f32>):
+    "tosa.yield"(%arg3) : (tensor<f32>) -> ()
+  }, {
+  ^bb1(%arg5 : tensor<f32>, %arg6 : tensor<f32>):
+    "tosa.yield"(%arg6) : (tensor<f32>) -> ()
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<*xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @if_test_dynamic
+func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
+  // CHECK: (tensor<i1>, tensor<2xf32>, tensor<3xf32>) -> tensor<?xf32>
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb1(%arg3 : tensor<2xf32>, %arg4 : tensor<3xf32>):
+    "tosa.yield"(%arg3) : (tensor<2xf32>) -> ()
+  }, {
+  ^bb1(%arg5 : tensor<2xf32>, %arg6 : tensor<3xf32>):
+    "tosa.yield"(%arg6) : (tensor<3xf32>) -> ()
+  }) : (tensor<i1>, tensor<2xf32>, tensor<3xf32>) -> (tensor<*xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @if_test_unranked
+func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
+  // CHECK: (tensor<i1>, tensor<f32>, tensor<3xf32>) -> tensor<*xf32>
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb1(%arg3 : tensor<f32>, %arg4 : tensor<3xf32>):
+    "tosa.yield"(%arg3) : (tensor<f32>) -> ()
+  }, {
+  ^bb1(%arg5 : tensor<f32>, %arg6 : tensor<3xf32>):
+    "tosa.yield"(%arg6) : (tensor<3xf32>) -> ()
+  }) : (tensor<i1>, tensor<f32>, tensor<3xf32>) -> (tensor<*xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @if_test_propagate
+func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
+  // CHECK: (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb1(%arg3 : tensor<*xf32>, %arg4 : tensor<*xf32>):
+    %1 = "tosa.add"(%arg3, %arg4) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+    "tosa.yield"(%1) : (tensor<*xf32>) -> ()
+  }, {
+  ^bb1(%arg5 : tensor<*xf32>, %arg6 : tensor<*xf32>):
+    %1 = "tosa.sub"(%arg5, %arg6) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+    "tosa.yield"(%1) : (tensor<*xf32>) -> ()
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<*xf32>)
+  return
+}