diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1376,12 +1376,12 @@
   }];

   let arguments = (ins
-    Variadic<Tosa_RankedTensor>:$input1,
+    Variadic<Tosa_Tensor>:$input1,
     I64Attr:$axis
   );

   let results = (outs
-    Tosa_RankedTensor:$output
+    Tosa_Tensor:$output
   );

   let hasCanonicalizer = 1;
@@ -1846,6 +1846,8 @@
 //===----------------------------------------------------------------------===//
 def Tosa_WhileOp : Tosa_Op<"while_loop", [
        DeclareOpInterfaceMethods<LoopLikeOpInterface>,
+       DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                                 ["inferReturnTypeComponents"]>,
        SingleBlockImplicitTerminator<"YieldOp">,
        RecursiveSideEffects]> {
   let summary = "output = input; While (Cond(output)) {output = Body(output)}";
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h
@@ -15,6 +15,7 @@

 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Types.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallVector.h"

@@ -61,6 +62,10 @@
     return ValueKnowledge(false, {}, Type());
   }

+  ShapedTypeComponents getShapedTypeComponents() const {
+    return hasRank ? ShapedTypeComponents(sizes) : ShapedTypeComponents();
+  }
+
   Type getType() const {
     if (hasRank)
       return RankedTensorType::get(llvm::makeArrayRef(sizes), dtype);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/DenseMap.h"

 using namespace mlir;
 using namespace mlir::tosa;
@@ -1437,13 +1438,52 @@
   }

   for (const ValueKnowledge &result : resultKnowledge) {
-    if (result.hasRank) {
-      inferredReturnShapes.push_back(ShapedTypeComponents(result.sizes));
-    } else {
-      inferredReturnShapes.push_back(ShapedTypeComponents());
+    inferredReturnShapes.push_back(result.getShapedTypeComponents());
+  }
+
+  return success();
+}
+
+LogicalResult WhileOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  llvm::SmallVector<tosa::YieldOp> yieldOps;
+  for (auto &block : *regions[1])
+    if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
+      yieldOps.push_back(returnOp);
+
+  // TOSA's while must have a tosa.yield as its terminator. If not found, this
+  // tosa.while is invalid.
+  if (yieldOps.empty())
+    return failure();
+
+  // Get the initial type information from the operand types.
+  llvm::SmallVector<ValueKnowledge> resultKnowledge;
+  resultKnowledge.reserve(yieldOps.front().getNumOperands());
+  for (auto operand : yieldOps.front().getOperands()) {
+    resultKnowledge.push_back(
+        ValueKnowledge::getKnowledgeFromType(operand.getType()));
+  }
+
+  for (auto yieldOp : yieldOps) {
+    if (resultKnowledge.size() != yieldOp.getNumOperands())
+      return failure();
+
+    for (auto it : llvm::enumerate(yieldOp.getOperands())) {
+      int32_t index = it.index();
+      if (auto meet = ValueKnowledge::meet(
+              resultKnowledge[index],
+              ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
+        resultKnowledge[index] = meet;
+      }
     }
   }

+  for (const ValueKnowledge &result : resultKnowledge) {
+    inferredReturnShapes.push_back(result.getShapedTypeComponents());
+  }
+
   return success();
 }
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -34,8 +34,9 @@

 void propagateShapesInRegion(Region &region);

-void propagateShapesToTosaIf(Operation &op) {
-  tosa::IfOp ifOp = dyn_cast<tosa::IfOp>(op);
+void propagateShapesToTosaIf(
+    Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
+  IfOp ifOp = dyn_cast<IfOp>(op);
   if (!ifOp)
     return;

@@ -44,6 +45,17 @@
   if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
     return;

+  for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
+    auto inferredTy = shapesStorage[op.getOperand(i)];
+    auto blockArg = frontBlock.getArgument(i - 1);
+    auto oldType = blockArg.getType().cast<ShapedType>();
+
+    if (inferredTy.hasRank()) {
+      Type newType = oldType.clone(inferredTy.getDims());
+      blockArg.setType(newType);
+    }
+  }
+
   for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
     ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
         ifOp.getOperand(i + 1).getType());
@@ -58,8 +70,110 @@
     propagateShapesInRegion(region);
   }
+}
+
+void propagateShapesToTosaWhile(
+    Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
+  WhileOp whileOp = dyn_cast<WhileOp>(op);
+  if (!whileOp)
+    return;
+
+  // Determine what the expected argument types are to the cond/body blocks.
+  // The expected arguments should be compatible with every iteration of the
+  // loop body / condition for tosa.while.
+  llvm::SmallVector<Type> argTypes;
+  for (auto operand : op.getOperands()) {
+    auto operandTy = operand.getType().cast<ShapedType>();
+    auto shapedTypeComponent = shapesStorage[operand];
+    if (shapedTypeComponent.hasRank()) {
+      auto newTy = operandTy.clone(shapedTypeComponent.getDims());
+      argTypes.push_back(newTy);
+    } else {
+      argTypes.push_back(operand.getType());
+    }
+  }
+
+  // Save out the type information so we can restore it at the end.
+  llvm::DenseMap<Value, Type> originalTypeMap;
+  for (auto &block : op.getRegion(1)) {
+    for (auto arg : block.getArguments())
+      originalTypeMap[arg] = arg.getType();
+    for (auto &op : block)
+      for (auto result : op.getResults())
+        originalTypeMap[result] = result.getType();
+  }
+
+  bool hasNewTypes = true;
+  while (hasNewTypes) {
+    // Set types on the block args.
+    Region &bodyRegion = op.getRegion(1);
+    Block &block = bodyRegion.front();
+    for (int i = 0, s = argTypes.size(); i < s; i++) {
+      block.getArgument(i).setType(argTypes[i]);
+    }
+
+    // Propagate to the end.
+    propagateShapesInRegion(bodyRegion);
+
+    // Find all the tosa yield types and verify there is at least one.
+    llvm::SmallVector<YieldOp> yieldOps;
+    for (auto &block : bodyRegion)
+      if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
+        yieldOps.push_back(yieldOp);
+
+    if (yieldOps.empty())
+      return;
+
+    // Using the new tosa yield types, infer the new subtypes.
+    llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
+    for (auto ty : argTypes) {
+      yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
+    }
+
+    for (auto yieldOp : yieldOps) {
+      for (auto it : llvm::enumerate(yieldOp.getOperands())) {
+        auto newKnowledge =
+            ValueKnowledge::getKnowledgeFromType(it.value().getType());
+        yieldTypeInfo[it.index()] =
+            ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
+      }
+    }
+
+    // This should never happen.
+    if (yieldTypeInfo.size() != argTypes.size())
+      return;
+
+    // Determine the new block args and see if any changed.
+    hasNewTypes = false;
+    for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
+      Type newType = yieldTypeInfo[i].getType();
+      hasNewTypes |= (newType != argTypes[i]);
+      argTypes[i] = newType;
+    }
+
+    // The types inferred in the block assume the operand types specified for
+    // this iteration. We need to restore the original types to ensure that
+    // future iterations only use the already specified types, not possible
+    // types from previous iterations.
+    for (auto &block : bodyRegion) {
+      for (auto arg : block.getArguments())
+        arg.setType(originalTypeMap[arg]);
+      for (auto &op : block)
+        for (auto result : op.getResults())
+          result.setType(originalTypeMap[result]);
+    }
+  }

-  return;
+  // We now set the block arguments according to the most recent shape
+  // inference results.
+  for (auto &region : op.getRegions()) {
+    for (unsigned int i = 0; i < argTypes.size(); i++) {
+      region.front().getArgument(i).setType(argTypes[i]);
+    }
+
+    propagateShapesInRegion(region);
+  }
 }

 void propagateShapesInRegion(Region &region) {
@@ -80,11 +194,11 @@

   for (auto &block : region) {
     for (Operation &op : block) {
-      if (op.getDialect()->getNamespace() !=
-          tosa::TosaDialect::getDialectNamespace())
+      if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
         continue;

-      propagateShapesToTosaIf(op);
+      propagateShapesToTosaIf(op, shapesStorage);
+      propagateShapesToTosaWhile(op, shapesStorage);

       InferShapedTypeOpInterface shapeInterface =
           dyn_cast<InferShapedTypeOpInterface>(op);
@@ -110,7 +224,7 @@
         if (isa<ReturnOp>(user))
           continue;
         if (user->getDialect()->getNamespace() ==
-            tosa::TosaDialect::getDialectNamespace())
+            TosaDialect::getDialectNamespace())
           continue;

         replaceable = false;
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1043,14 +1043,16 @@

 // CHECK-LABEL: @if_test_simple
 func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
+  %a = "tosa.log"(%arg0) : (tensor<f32>) -> tensor<*xf32>
+  %b = "tosa.log"(%arg1) : (tensor<f32>) -> tensor<*xf32>
   // CHECK: (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
-  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
-  ^bb1(%arg3 : tensor<f32>, %arg4 : tensor<f32>):
-    "tosa.yield"(%arg3) : (tensor<f32>) -> ()
+  %0 = "tosa.cond_if"(%arg2, %a, %b) ({
+  ^bb1(%arg3 : tensor<*xf32>, %arg4 : tensor<*xf32>):
+    "tosa.yield"(%arg3) : (tensor<*xf32>) -> ()
   }, {
-  ^bb1(%arg5 : tensor<f32>, %arg6 : tensor<f32>):
-    "tosa.yield"(%arg6) : (tensor<f32>) -> ()
-  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<*xf32>)
+  ^bb1(%arg5 : tensor<*xf32>, %arg6 : tensor<*xf32>):
+    "tosa.yield"(%arg6) : (tensor<*xf32>) -> ()
+  }) : (tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>)
   return
 }
@@ -1100,3 +1102,88 @@
   }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<*xf32>)
   return
 }
+
+// -----
+
+// CHECK-LABEL: @while_test
+func @while_test(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
+  // CHECK: "tosa.add"
+  // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i32>
+  %0 = "tosa.add"(%arg0, %arg0) : (tensor<i32>, tensor<i32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.while_loop"
+  %1 = "tosa.while_loop"(%0) ( {
+
+  // CHECK: ^bb0
+  // CHECK-SAME: tensor<i32>
+  ^bb0(%arg2: tensor<*xi32>):
+    %2 = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+    // CHECK: "tosa.greater_equal"
+    // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = "tosa.greater_equal"(%2, %arg2) : (tensor<i32>, tensor<*xi32>) -> tensor<*xi1>
+    // CHECK: "tosa.yield"
+    // CHECK-SAME: tensor<i1>
+    "tosa.yield"(%3) : (tensor<*xi1>) -> ()
+  }, {
+  // CHECK: ^bb0
+  // CHECK-SAME: tensor<i32>
+  ^bb0(%arg2: tensor<*xi32>):
+    %2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+
+    // CHECK: "tosa.add"
+    // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i32>
+    %3 = "tosa.add"(%arg2, %2) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
+
+    // CHECK: "tosa.yield"
+    // CHECK-SAME: tensor<i32>
+    "tosa.yield"(%3) : (tensor<*xi32>) -> ()
+
+  // CHECK: (tensor<i32>) -> tensor<i32>
+  }) : (tensor<*xi32>) -> (tensor<*xi32>)
+
+  // CHECK: tensor.cast
+  return %1 : tensor<*xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @while_test
+func @while_test(%arg0 : tensor<i32>, %arg1 : tensor<1xi32>) -> () {
+  // CHECK: "tosa.while_loop"
+  %1:2 = "tosa.while_loop"(%arg0, %arg1) ( {
+
+  // CHECK: ^bb0
+  // CHECK-SAME: tensor<i32>
+  // CHECK-SAME: tensor<?xi32>
+  ^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xi32>):
+    %2 = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+
+    // CHECK: "tosa.greater_equal"
+    // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = "tosa.greater_equal"(%2, %arg2) : (tensor<i32>, tensor<*xi32>) -> tensor<*xi1>
+    "tosa.yield"(%3) : (tensor<*xi1>) -> ()
+  }, {
+
+  // CHECK: ^bb0
+  // CHECK-SAME: tensor<i32>
+  // CHECK-SAME: tensor<?xi32>
+  ^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xi32>):
+    %2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+
+    // CHECK: "tosa.add"
+    // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i32>
+    %3 = "tosa.add"(%arg2, %2) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
+
+    // CHECK: "tosa.concat"
+    // CHECK-SAME: (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
+    %4 = "tosa.concat"(%arg3, %arg3) { axis = 0 : i64 } : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
+
+    // CHECK: "tosa.yield"
+    // CHECK-SAME: tensor<i32>
+    // CHECK-SAME: tensor<?xi32>
+    "tosa.yield"(%3, %4) : (tensor<*xi32>, tensor<*xi32>) -> ()
+
+  // CHECK: (tensor<i32>, tensor<1xi32>) -> (tensor<i32>, tensor<?xi32>)
+  }) : (tensor<i32>, tensor<1xi32>) -> (tensor<*xi32>, tensor<*xi32>)
+  return
+}
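
Note on the inference strategy above: both WhileOp::inferReturnTypeComponents and the
fixed-point loop in propagateShapesToTosaWhile rely on ValueKnowledge::meet, which
folds the types seen at every tosa.yield into the most generic type compatible with
all of them: a static dimension survives only where every yield agrees, a disagreement
becomes a dynamic dimension, and a rank mismatch becomes unranked. The following is a
minimal standalone C++ sketch of that lattice operation for illustration only;
ShapeFact and meetShapes are hypothetical names, not the actual ValueKnowledge API:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // -1 encodes a dynamic dimension, mirroring ShapedType::kDynamicSize.
    struct ShapeFact {
      bool hasRank;
      std::vector<int64_t> sizes;
    };

    // Keep a static dimension only where both facts agree; generalize a
    // disagreement to dynamic, and a rank mismatch to unranked.
    ShapeFact meetShapes(const ShapeFact &lhs, const ShapeFact &rhs) {
      if (!lhs.hasRank || !rhs.hasRank || lhs.sizes.size() != rhs.sizes.size())
        return ShapeFact{false, {}};
      ShapeFact out{true, {}};
      for (std::size_t i = 0; i < lhs.sizes.size(); ++i)
        out.sizes.push_back(lhs.sizes[i] == rhs.sizes[i] ? lhs.sizes[i] : -1);
      return out;
    }

Applied to the second test above: the loop-carried value enters as tensor<1xi32> but
is tensor<2xi32> after one concat along axis 0, so the meet of the two generalizes the
block argument to tensor<?xi32>, which is exactly the fixed point the CHECK lines expect.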