diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1789,6 +1789,8 @@
 // Further described in docs/Rationale/RationaleTOSADialect.md .
 //===----------------------------------------------------------------------===//
 def Tosa_IfOp : Tosa_Op<"cond_if", [
+      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                                ["inferReturnTypeComponents"]>,
       SingleBlockImplicitTerminator<"YieldOp">,
       RecursiveSideEffects]> {
   let summary = "Conditional if operator";
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h
@@ -0,0 +1,178 @@
+//===-- ShapeUtils.h - TOSA shape support declarations ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Class declarations for shape utilities meant to assist shape propagation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TOSA_UTILS_SHAPEUTILS_H
+#define MLIR_DIALECT_TOSA_UTILS_SHAPEUTILS_H
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+namespace tosa {
+/// Statically known information for a particular Value.
+///
+/// This struct currently tracks only information relevant for tensor/array-like
+/// shaped types. It is fine to associate a `ValueKnowledge` with a non-shaped
+/// type as long as it is in the default "no knowledge" state returned by
+/// `getPessimisticValueState`. The important invariant is that we cannot
+/// claim to know something about a value which is false.
+///
+/// This class could also be called "dataflow facts", "lattice value", etc.
+struct ValueKnowledge {
+  ValueKnowledge() = delete;
+  ValueKnowledge(bool hasRank, llvm::ArrayRef<int64_t> newSizes, Type dtype)
+      : hasError(false), hasRank(hasRank), dtype(dtype) {
+    sizes.reserve(newSizes.size());
+    for (auto size : newSizes)
+      sizes.push_back(size);
+  }
+
+  operator bool() const { return !hasError; }
+
+  // Get the static knowledge intrinsic to `type`.
+  static ValueKnowledge getKnowledgeFromType(Type type) {
+    ValueKnowledge result = getPessimisticValueState();
+    if (auto shapedType = type.dyn_cast<ShapedType>()) {
+      if (shapedType.hasRank()) {
+        result.hasRank = true;
+        result.sizes.reserve(shapedType.getRank());
+        for (auto dim : shapedType.getShape())
+          result.sizes.push_back(dim);
+      }
+      result.dtype = shapedType.getElementType();
+    }
+    return result;
+  }
+
+  // Return a pessimistic/conservative value state without assuming any
+  // knowledge about the IR.
+  static ValueKnowledge getPessimisticValueState() {
+    return ValueKnowledge(false, {}, Type());
+  }
+
+  Type getType() const {
+    if (hasRank)
+      return RankedTensorType::get(llvm::makeArrayRef(sizes), dtype);
+    return UnrankedTensorType::get(dtype);
+  }
+
+  bool operator==(const ValueKnowledge &rhs) const {
+    return hasRank == rhs.hasRank && sizes == rhs.sizes && dtype == rhs.dtype;
+  }
+
+  // Given two pieces of static knowledge, calculate conservatively the
+  // information we can be sure about.
+  static ValueKnowledge join(const ValueKnowledge &lhs,
+                             const ValueKnowledge &rhs) {
+    // Mental model: All conditions are checking how to change from the safe "no
+    // knowledge" default-initialized state to a state with more knowledge
+    // consistent with lhs and rhs.
+    ValueKnowledge result = getPessimisticValueState();
+    result.hasError = true;
+
+    if (!lhs || !rhs || lhs.dtype != rhs.dtype)
+      return result;
+
+    result.hasError = false;
+    result.dtype = lhs.dtype;
+
+    if (!lhs.hasRank && !rhs.hasRank)
+      return result;
+
+    if (!rhs.hasRank) {
+      result.hasRank = true;
+      result.sizes = lhs.sizes;
+      return result;
+    }
+
+    if (!lhs.hasRank) {
+      result.hasRank = true;
+      result.sizes = rhs.sizes;
+      return result;
+    }
+
+    if (lhs.sizes.size() != rhs.sizes.size())
+      return result;
+
+    result.hasRank = true;
+    result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamicSize);
+    for (auto i : llvm::seq<int64_t>(0, result.sizes.size())) {
+      int64_t lhsSize = lhs.sizes[i];
+      int64_t rhsSize = rhs.sizes[i];
+      int64_t &resultSize = result.sizes[i];
+      if (lhsSize == ShapedType::kDynamicSize) {
+        resultSize = rhsSize;
+      } else if (rhsSize == ShapedType::kDynamicSize) {
+        resultSize = lhsSize;
+      } else if (lhsSize == rhsSize) {
+        resultSize = lhsSize;
+      } else {
+        result.hasError = true;
+      }
+    }
+
+    return result;
+  }
+
+  // Given two types, generate a new ValueKnowledge that covers both cases.
+  // E.g. if the rank of the LHS and RHS differ, the resulting tensor
+  // has unknown rank.
+  static ValueKnowledge meet(const ValueKnowledge &lhs,
+                             const ValueKnowledge &rhs) {
+    ValueKnowledge result = getPessimisticValueState();
+    result.hasError = true;
+
+    if (!lhs || !rhs || lhs.dtype != rhs.dtype)
+      return result;
+
+    result.hasError = false;
+    result.dtype = lhs.dtype;
+
+    if (!lhs.hasRank || !rhs.hasRank) {
+      result.hasRank = false;
+      return result;
+    }
+
+    if (lhs.sizes.size() != rhs.sizes.size()) {
+      result.hasRank = false;
+      return result;
+    }
+
+    result.hasRank = true;
+    result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamicSize);
+    for (int i = 0, e = lhs.sizes.size(); i < e; i++) {
+      if (lhs.sizes[i] == rhs.sizes[i])
+        result.sizes[i] = lhs.sizes[i];
+    }
+
+    return result;
+  }
+
+  // Whether the value information has an error.
+  bool hasError;
+  // Whether the value has known rank.
+  bool hasRank;
+  // If `hasRank`, the sizes along each dimension. Unknown sizes are
+  // represented as `ShapedType::kDynamicSize`.
+  llvm::SmallVector<int64_t> sizes;
+  // The dtype of a tensor.
+  // This is equal to nullptr if we don't know that it is a specific concrete
+  // type.
+  Type dtype;
+};
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TOSA_UTILS_SHAPEUTILS_H
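For intuition, here is a minimal standalone sketch (not part of the patch) of how the two lattice operations behave. It assumes the ShapeUtils.h header added above and an MLIRContext for building element types.

#include <cassert>

#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;
using namespace mlir::tosa;

int main() {
  MLIRContext ctx;
  Type f32 = FloatType::getF32(&ctx);

  // join combines two facts about the *same* value, filling each dynamic
  // dimension from whichever side knows it:
  // tensor<?x3xf32> join tensor<2x?xf32> == tensor<2x3xf32>.
  ValueKnowledge lhs(/*hasRank=*/true, {ShapedType::kDynamicSize, 3}, f32);
  ValueKnowledge rhs(/*hasRank=*/true, {2, ShapedType::kDynamicSize}, f32);
  ValueKnowledge joined = ValueKnowledge::join(lhs, rhs);
  assert(joined && joined.getType() == RankedTensorType::get({2, 3}, f32));

  // meet goes the other way: the result must cover *either* value, so
  // mismatched ranks generalize to an unranked tensor:
  // tensor<f32> meet tensor<3xf32> == tensor<*xf32>.
  ValueKnowledge met =
      ValueKnowledge::meet(ValueKnowledge(/*hasRank=*/true, {}, f32),
                           ValueKnowledge(/*hasRank=*/true, {3}, f32));
  assert(met && met.getType() == UnrankedTensorType::get(f32));
  return 0;
}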
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
@@ -1301,6 +1302,54 @@
   return success();
 }
 
+LogicalResult IfOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  llvm::SmallVector<tosa::YieldOp> yieldOps;
+  for (Region *region : regions) {
+    for (auto &block : *region)
+      if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
+        yieldOps.push_back(returnOp);
+  }
+
+  if (yieldOps.empty())
+    return failure();
+
+  // Get the initial type information for the yield op.
+  llvm::SmallVector<ValueKnowledge> resultKnowledge;
+  resultKnowledge.reserve(yieldOps.front().getNumOperands());
+  for (auto operand : yieldOps.front().getOperands()) {
+    resultKnowledge.push_back(
+        ValueKnowledge::getKnowledgeFromType(operand.getType()));
+  }
+
+  for (auto yieldOp : yieldOps) {
+    if (resultKnowledge.size() != yieldOp.getNumOperands())
+      return failure();
+
+    for (auto it : llvm::enumerate(yieldOp.getOperands())) {
+      int32_t index = it.index();
+      auto meet = ValueKnowledge::meet(
+          resultKnowledge[index],
+          ValueKnowledge::getKnowledgeFromType(it.value().getType()));
+      if (!meet)
+        continue;
+      resultKnowledge[index] = meet;
+    }
+  }
+
+  for (auto result : resultKnowledge) {
+    if (result.hasRank) {
+      inferredReturnShapes.push_back(ShapedTypeComponents(result.sizes));
+    } else {
+      inferredReturnShapes.push_back(ShapedTypeComponents());
+    }
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Definitions.
 //===----------------------------------------------------------------------===//
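The implementation above seeds each result's knowledge from the first tosa.yield and then meets in every other yield. A small sketch (again outside the patch) of that per-result folding, mirroring the @if_test_dynamic case in the tests below:

#include <cassert>

#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace mlir::tosa;

int main() {
  MLIRContext ctx;
  Type f32 = FloatType::getF32(&ctx);

  // Stand-ins for the operand types of the then/else tosa.yield terminators.
  llvm::SmallVector<Type> yieldTypes = {RankedTensorType::get({2}, f32),
                                        RankedTensorType::get({3}, f32)};

  // Seed from the first yield, then meet in the rest, as
  // IfOp::inferReturnTypeComponents does for each result index.
  ValueKnowledge knowledge =
      ValueKnowledge::getKnowledgeFromType(yieldTypes.front());
  for (Type type : llvm::drop_begin(yieldTypes)) {
    auto met = ValueKnowledge::meet(
        knowledge, ValueKnowledge::getKnowledgeFromType(type));
    if (met)
      knowledge = met;
  }

  // Same rank but conflicting static sizes, so the shared shape is
  // tensor<?xf32>.
  assert(knowledge.getType() ==
         RankedTensorType::get({ShapedType::kDynamicSize}, f32));
  return 0;
}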
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -30,137 +31,57 @@
 
 namespace {
 
-// -----------------------------------------------------------------------------
-// Analysis.
-// -----------------------------------------------------------------------------
-
-static Type joinElementTypes(Type lhs, Type rhs) {
-  return lhs == rhs ? lhs : Type();
-}
-
-namespace {
-// Statically known information for a particular Value.
-//
-// This struct currently tracks only information relevant for tensor/array-like
-// shaped types. It is fine to associate a `ValueKnowledge` with a non-shaped
-// type as long as it is in the default "no knowledge" state returned by
-// `getPessimisticValueState`. The important invariant is that we cannot
-// claim to know something about a value which is false.
-//
-// This class could also be called "dataflow facts", "lattice value", etc.
-struct ValueKnowledge {
-  ValueKnowledge() = delete;
-  ValueKnowledge(bool hasSizes, std::vector<int64_t> sizes, Type dtype)
-      : hasSizes(hasSizes), sizes(sizes), dtype(dtype) {
-    assert(sizes.size() == 0 || hasSizes);
-  }
-
-  // Get the static knowledge intrinsic to `type`.
-  static ValueKnowledge getKnowledgeFromType(Type type) {
-    ValueKnowledge result = getPessimisticValueState(type.getContext());
-    if (auto shapedType = type.dyn_cast<ShapedType>()) {
-      if (shapedType.hasRank()) {
-        result.hasSizes = true;
-        result.sizes = shapedType.getShape();
-      }
-      result.dtype = shapedType.getElementType();
+void propagateShapesInRegion(Region &region);
+
+void propagateShapesToTosaIf(Operation &op) {
+  tosa::IfOp ifOp = dyn_cast<tosa::IfOp>(op);
+  if (!ifOp)
+    return;
+
+  for (auto &region : op.getRegions()) {
+    Block &frontBlock = region.front();
+    if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
+      return;
+
+    for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
+      ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
+          ifOp.getOperand(i + 1).getType());
+      ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
+          frontBlock.getArgument(i).getType());
+      ValueKnowledge joinedKnowledge =
+          ValueKnowledge::join(operandKnowledge, blockKnowledge);
+      if (!joinedKnowledge)
+        continue;
+      frontBlock.getArgument(i).setType(joinedKnowledge.getType());
     }
-    return result;
-  }
 
-  // Return a pessimistic/conservative value state without assuming any knowlege
-  // about the IR.
-  static ValueKnowledge getPessimisticValueState(MLIRContext *context) {
-    return ValueKnowledge(false, {}, Type());
+    propagateShapesInRegion(region);
   }
 
-  Type getType() const {
-    if (hasSizes) {
-      return RankedTensorType::get(llvm::makeArrayRef(sizes), dtype);
-    }
-    return UnrankedTensorType::get(dtype);
-  }
-
-  bool operator==(const ValueKnowledge &rhs) const {
-    return std::make_tuple(hasSizes, sizes, dtype) ==
-           std::make_tuple(rhs.hasSizes, rhs.sizes, rhs.dtype);
-  }
-
-  // Given two pieces of static knowledge, calculate conservatively the
-  // information we can be sure about.
-  static ValueKnowledge join(const ValueKnowledge &lhs,
-                             const ValueKnowledge &rhs) {
-    // Mental model: All conditions are checking how to change from the safe "no
-    // knowledge" default-initialized state to a state with more knowledge
-    // consistent with lhs and rhs.
-    ValueKnowledge result = getPessimisticValueState(nullptr);
-
-    if (lhs.hasSizes && !rhs.hasSizes) {
-      result.hasSizes = true;
-      result.sizes = lhs.sizes;
-    } else if (!lhs.hasSizes && rhs.hasSizes) {
-      result.hasSizes = true;
-      result.sizes = rhs.sizes;
-    } else if (lhs.hasSizes && rhs.hasSizes &&
-               lhs.sizes.size() == rhs.sizes.size()) {
-      result.hasSizes = true;
-      result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamicSize);
-      for (int i = 0, e = result.sizes.size(); i != e; i++) {
-        int64_t lhsSize = lhs.sizes[i];
-        int64_t rhsSize = rhs.sizes[i];
-        int64_t &resultSize = result.sizes[i];
-        if (lhsSize == ShapedType::kDynamicSize) {
-          resultSize = rhsSize;
-        } else if (rhsSize == ShapedType::kDynamicSize) {
-          resultSize = lhsSize;
-        } else if (lhsSize == rhsSize) {
-          resultSize = lhsSize;
-        }
-      }
-    }
-
-    result.dtype = joinElementTypes(lhs.dtype, rhs.dtype);
-    return result;
-  }
-
-  // Whether the Value is known to have a list of sizes.
-  bool hasSizes;
-  // If `hasSizes`, the sizes along each rank. Unknown sizes are represented as
-  // `ShapedType::kDynamicSize`.
-  std::vector<int64_t> sizes;
-  // The dtype of a tensor.
-  // This is equal to nullptr if we don't know that it is a specific concrete
-  // type.
-  Type dtype;
-};
-
-} // namespace
+  return;
+}
 
-/// Pass that enables broadcast by making all input arrays have the same
-/// number of dimensions. Insert RESHAPE operations to lower rank operand
-struct TosaInferShapes : public TosaInferShapesBase<TosaInferShapes> {
-public:
-  void runOnFunction() override {
-    FuncOp func = getOperation();
+void propagateShapesInRegion(Region &region) {
+  for (auto &block : region) {
+    for (Operation &op : block) {
+      if (op.getDialect()->getNamespace() !=
+          tosa::TosaDialect::getDialectNamespace())
+        continue;
 
-    IRRewriter rewriter(func.getContext());
+      propagateShapesToTosaIf(op);
 
-    func.walk([&](Operation *op) {
-      if (op->getDialect()->getNamespace() !=
-          tosa::TosaDialect::getDialectNamespace())
-        return;
       InferShapedTypeOpInterface shapeInterface =
           dyn_cast<InferShapedTypeOpInterface>(op);
       if (!shapeInterface)
-        return;
+        continue;
 
       SmallVector<ShapedTypeComponents> returnedShapes;
       if (shapeInterface
              .inferReturnTypeComponents(
-                  op->getContext(), op->getLoc(), op->getOperands(),
-                  op->getAttrDictionary(), op->getRegions(), returnedShapes)
+                  op.getContext(), op.getLoc(), op.getOperands(),
+                  op.getAttrDictionary(), op.getRegions(), returnedShapes)
              .succeeded()) {
-        for (auto it : llvm::zip(op->getResults(), returnedShapes)) {
+        for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
          Value result = std::get<0>(it);
          ShapedTypeComponents predictedShape = std::get<1>(it);
@@ -183,11 +104,10 @@
              ValueKnowledge::getKnowledgeFromType(resultTy);
 
          // Compute the knowledge based on the inferred type.
-          auto inferredKnowledge =
-              ValueKnowledge::getPessimisticValueState(op->getContext());
+          auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
          inferredKnowledge.dtype = resultTy.cast<ShapedType>().getElementType();
-          inferredKnowledge.hasSizes = predictedShape.hasRank();
+          inferredKnowledge.hasRank = predictedShape.hasRank();
          if (predictedShape.hasRank()) {
            for (auto dim : predictedShape.getDims()) {
              inferredKnowledge.sizes.push_back(dim);
@@ -200,10 +120,25 @@
 
          // Compute the new type based on the joined version.
          auto newKnowledge =
              ValueKnowledge::join(currentKnowledge, inferredKnowledge);
+          if (!newKnowledge)
+            continue;
          result.setType(newKnowledge.getType());
        }
      }
-    });
+    }
+  }
+}
+
+/// Pass that performs shape propagation across TOSA operations. This includes
+/// propagating shapes into the regions of if/while operations.
+struct TosaInferShapes : public TosaInferShapesBase<TosaInferShapes> {
+public:
+  void runOnFunction() override {
+    FuncOp func = getOperation();
+
+    IRRewriter rewriter(func.getContext());
+
+    propagateShapesInRegion(func.body());
 
     // Insert UnrealizedConversionCasts to guarantee ReturnOp agrees with
     // the FuncOp type.
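For context, a hypothetical driver snippet (the helper name is invented for illustration) showing how the pass can be scheduled programmatically; the tests below reach the same pass through mlir-opt --tosa-infer-shapes.

#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Schedule shape inference on every function in the module. The pass now
// propagates shapes into the regions of tosa.cond_if as well.
void addTosaShapePropagation(mlir::PassManager &pm) {
  pm.addNestedPass<mlir::FuncOp>(mlir::tosa::createTosaInferShapesPass());
}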
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -774,7 +774,6 @@
 
 // -----
 
-
 // CHECK-LABEL: @conv2d_strided
 func @conv2d_strided(%input: tensor<1x13x14x1xf32>, %weights: tensor<1x1x1x1xf32>, %bias: tensor<1xf32>) -> () {
   // CHECK: -> tensor<1x5x7x1xf32>
@@ -1033,12 +1032,71 @@
   %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [0.000000e+00 : f32, 0.000000e+00 : f32], output_size = [-1, -1], shift = 0 : i32, stride = [0, 0], stride_fp = [5.000000e-01 : f32, 1.000000e+00 : f32]} : (tensor<1x2x4x1xi32>) -> tensor<?x?x?x?xi32>
   return
 }
-
-// -----
-
 // CHECK-LABEL: @resize_fp_offsetted
 func @resize_fp_offsetted(%arg0: tensor<1x2x4x1xi32>) {
   // CHECK: -> tensor<1x4x6x1xi32>
   %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [2.500000e-01 : f32, 2.500000e-01 : f32], output_size = [-1, -1], shift = 0 : i32, stride = [0, 0], stride_fp = [2.500000e-01 : f32, 5.000000e-01 : f32]} : (tensor<1x2x4x1xi32>) -> tensor<?x?x?x?xi32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @if_test_simple
+func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
+  // CHECK: (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb1(%arg3 : tensor<f32>, %arg4 : tensor<f32>):
+    "tosa.yield"(%arg3) : (tensor<f32>) -> ()
+  }, {
+  ^bb1(%arg5 : tensor<f32>, %arg6 : tensor<f32>):
+    "tosa.yield"(%arg6) : (tensor<f32>) -> ()
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<*xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @if_test_dynamic
+func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
+  // CHECK: (tensor<i1>, tensor<2xf32>, tensor<3xf32>) -> tensor<?xf32>
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb1(%arg3 : tensor<2xf32>, %arg4 : tensor<3xf32>):
+    "tosa.yield"(%arg3) : (tensor<2xf32>) -> ()
+  }, {
+  ^bb1(%arg5 : tensor<2xf32>, %arg6 : tensor<3xf32>):
+    "tosa.yield"(%arg6) : (tensor<3xf32>) -> ()
+  }) : (tensor<i1>, tensor<2xf32>, tensor<3xf32>) -> (tensor<*xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @if_test_unranked
+func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
+  // CHECK: (tensor<i1>, tensor<f32>, tensor<3xf32>) -> tensor<*xf32>
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb1(%arg3 : tensor<f32>, %arg4 : tensor<3xf32>):
+    "tosa.yield"(%arg3) : (tensor<f32>) -> ()
+  }, {
+  ^bb1(%arg5 : tensor<f32>, %arg6 : tensor<3xf32>):
+    "tosa.yield"(%arg6) : (tensor<3xf32>) -> ()
+  }) : (tensor<i1>, tensor<f32>, tensor<3xf32>) -> (tensor<*xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @if_test_propagate
+func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
+  // CHECK: (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb1(%arg3 : tensor<*xf32>, %arg4 : tensor<*xf32>):
+    %1 = "tosa.add"(%arg3, %arg4) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+    "tosa.yield"(%1) : (tensor<*xf32>) -> ()
+  }, {
+  ^bb1(%arg5 : tensor<*xf32>, %arg6 : tensor<*xf32>):
+    %1 = "tosa.sub"(%arg5, %arg6) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+    "tosa.yield"(%1) : (tensor<*xf32>) -> ()
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<*xf32>)
+  return
+}