diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1376,12 +1376,12 @@
   }];
 
   let arguments = (ins
-    Variadic<Tosa_RankedTensor>:$input1,
+    Variadic<Tosa_Tensor>:$input1,
     I64Attr:$axis
   );
 
   let results = (outs
-    Tosa_RankedTensor:$output
+    Tosa_Tensor:$output
   );
 
   let hasCanonicalizer = 1;
@@ -1846,6 +1846,8 @@
 //===----------------------------------------------------------------------===//
 def Tosa_WhileOp : Tosa_Op<"while_loop", [
        DeclareOpInterfaceMethods<LoopLikeOpInterface>,
+       DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                                 ["inferReturnTypeComponents"]>,
        SingleBlockImplicitTerminator<"YieldOp">,
        RecursiveSideEffects]> {
   let summary = "output = input; While (Cond(output)) {output = Body(output)}";
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/DenseMap.h"
 
 using namespace mlir;
 using namespace mlir::tosa;
@@ -1437,13 +1438,55 @@
   }
 
   for (const ValueKnowledge &result : resultKnowledge) {
-    if (result.hasRank) {
-      inferredReturnShapes.push_back(ShapedTypeComponents(result.sizes));
-    } else {
-      inferredReturnShapes.push_back(ShapedTypeComponents());
-    }
+    inferredReturnShapes.push_back(result.hasRank
+                                       ? ShapedTypeComponents(result.sizes)
+                                       : ShapedTypeComponents());
+  }
+
+  return success();
+}
+
+LogicalResult WhileOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  llvm::SmallVector<tosa::YieldOp> yieldOps;
+  for (auto &block : *regions[1])
+    if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
+      yieldOps.push_back(returnOp);
+
+  if (yieldOps.empty())
+    return failure();
+
+  // Get the initial type information from the operand types.
+  llvm::SmallVector<ValueKnowledge> resultKnowledge;
+  resultKnowledge.reserve(yieldOps.front().getNumOperands());
+  for (auto operand : yieldOps.front().getOperands()) {
+    resultKnowledge.push_back(
+        ValueKnowledge::getKnowledgeFromType(operand.getType()));
+  }
+
+  for (auto yieldOp : yieldOps) {
+    if (resultKnowledge.size() != yieldOp.getNumOperands())
+      return failure();
+
+    for (auto it : llvm::enumerate(yieldOp.getOperands())) {
+      int32_t index = it.index();
+      auto meet = ValueKnowledge::meet(
+          resultKnowledge[index],
+          ValueKnowledge::getKnowledgeFromType(it.value().getType()));
+      if (!meet)
+        continue;
+      resultKnowledge[index] = meet;
     }
   }
 
+  for (const ValueKnowledge &result : resultKnowledge) {
+    inferredReturnShapes.push_back(result.hasRank
+                                       ? ShapedTypeComponents(result.sizes)
+                                       : ShapedTypeComponents());
+  }
+
   return success();
 }
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -34,7 +34,8 @@
 
 void propagateShapesInRegion(Region &region);
 
-void propagateShapesToTosaIf(Operation &op) {
+void propagateShapesToTosaIf(
+    Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
   tosa::IfOp ifOp = dyn_cast<tosa::IfOp>(op);
   if (!ifOp)
     return;
@@ -44,6 +45,17 @@
   if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
     return;
 
+  for (unsigned int i = 1; i < op.getNumOperands(); i++) {
+    auto inferredTy = shapesStorage[op.getOperand(i)];
+    auto blockArg = frontBlock.getArgument(i - 1);
+    auto oldType = blockArg.getType().cast<ShapedType>();
+
+    if (inferredTy.hasRank()) {
+      Type newType = oldType.clone(inferredTy.getDims());
+      blockArg.setType(newType);
+    }
+  }
+
   for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
     ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
         ifOp.getOperand(i + 1).getType());
@@ -58,8 +70,106 @@
 
     propagateShapesInRegion(region);
   }
+}
+
+void propagateShapesToTosaWhile(
+    Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
+  tosa::WhileOp whileOp = dyn_cast<tosa::WhileOp>(op);
+  if (!whileOp)
+    return;
+
+  // Determine what the expected argument types are to the cond/body blocks.
+  // These need to remain consistent across every iteration.
+  llvm::SmallVector<Type> argTypes;
+  for (auto operand : op.getOperands()) {
+    auto operandTy = operand.getType().cast<ShapedType>();
+    auto shapedTypeComponent = shapesStorage[operand];
+    if (shapedTypeComponent.hasRank()) {
+      auto newTy = operandTy.clone(shapedTypeComponent.getDims());
+      argTypes.push_back(newTy);
+    } else {
+      argTypes.push_back(operand.getType());
+    }
+  }
+
+  // Save out the type information so we can restore at the end.
+  llvm::DenseMap<Value, Type> backupTypeMap;
+  for (auto &block : op.getRegion(1)) {
+    for (auto arg : block.getArguments())
+      backupTypeMap[arg] = arg.getType();
+    for (auto &op : block)
+      for (auto result : op.getResults())
+        backupTypeMap[result] = result.getType();
+  }
+
+  bool hasNewTypes = true;
+  while (hasNewTypes) {
+
+    // Set types on the block args.
+    Region &bodyRegion = op.getRegion(1);
+    Block &block = bodyRegion.front();
+    for (int i = 0, s = argTypes.size(); i < s; i++) {
+      block.getArgument(i).setType(argTypes[i]);
+    }
+
+    // Propagate to the end.
+    propagateShapesInRegion(bodyRegion);
+
+    // Find all the tosa yield types and verify there is at least one.
+    llvm::SmallVector<tosa::YieldOp> yieldOps;
+    for (auto &block : bodyRegion)
+      if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
+        yieldOps.push_back(returnOp);
+
+    if (yieldOps.empty())
+      return;
+
+    // Using the new tosa yield types, infer the new subtypes.
+    llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
+    for (auto ty : argTypes) {
+      yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
+    }
+
+    for (auto yieldOp : yieldOps) {
+      for (auto it : llvm::enumerate(yieldOp.getOperands())) {
+        auto newKnowledge =
+            ValueKnowledge::getKnowledgeFromType(it.value().getType());
+        yieldTypeInfo[it.index()] =
+            ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
+      }
+    }
+
+    // This should never happen.
+    if (yieldTypeInfo.size() != argTypes.size())
+      return;
+
+    // Determine the new block args and see if any changed.
+    hasNewTypes = false;
+    for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
+      Type newType = yieldTypeInfo[i].getType();
+      hasNewTypes |= (newType != argTypes[i]);
+      argTypes[i] = newType;
+    }
+
+    // Restore the original types in case we need to do a new pass.
+    for (auto &block : bodyRegion) {
+      for (auto arg : block.getArguments())
+        arg.setType(backupTypeMap[arg]);
+      for (auto &op : block)
+        for (auto result : op.getResults())
+          result.setType(backupTypeMap[result]);
+    }
+  }
 
-  return;
+  // Now we set the eventual types on the args of each block, then propagate
+  // through the region.
+  for (auto &region : op.getRegions()) {
+    for (unsigned int i = 0; i < argTypes.size(); i++) {
+      region.front().getArgument(i).setType(argTypes[i]);
+    }
+
+    propagateShapesInRegion(region);
+  }
 }
 
 void propagateShapesInRegion(Region &region) {
@@ -84,7 +194,8 @@
             tosa::TosaDialect::getDialectNamespace())
       continue;
 
-    propagateShapesToTosaIf(op);
+    propagateShapesToTosaIf(op, shapesStorage);
+    propagateShapesToTosaWhile(op, shapesStorage);
 
     InferShapedTypeOpInterface shapeInterface =
         dyn_cast<InferShapedTypeOpInterface>(op);
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1043,14 +1043,16 @@
 
 // CHECK-LABEL: @if_test_simple
 func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
+  %a = "tosa.log"(%arg0) : (tensor<f32>) -> tensor<*xf32>
+  %b = "tosa.log"(%arg1) : (tensor<f32>) -> tensor<*xf32>
   // CHECK: (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
-  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
-  ^bb1(%arg3 : tensor<f32>, %arg4 : tensor<f32>):
-    "tosa.yield"(%arg3) : (tensor<f32>) -> ()
+  %0 = "tosa.cond_if"(%arg2, %a, %b) ({
+  ^bb1(%arg3 : tensor<*xf32>, %arg4 : tensor<*xf32>):
+    "tosa.yield"(%arg3) : (tensor<*xf32>) -> ()
   }, {
-  ^bb1(%arg5 : tensor<f32>, %arg6 : tensor<f32>):
-    "tosa.yield"(%arg6) : (tensor<f32>) -> ()
-  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<*xf32>)
+  ^bb1(%arg5 : tensor<*xf32>, %arg6 : tensor<*xf32>):
+    "tosa.yield"(%arg6) : (tensor<*xf32>) -> ()
+  }) : (tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>)
   return
 }
@@ -1100,3 +1102,88 @@
   }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<*xf32>)
   return
 }
+
+// -----
+
+// CHECK-LABEL: @while_test
+func @while_test(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
+  // CHECK: "tosa.add"
+  // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i32>
+  %0 = "tosa.add"(%arg0, %arg0) : (tensor<i32>, tensor<i32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.while_loop"
+  %1 = "tosa.while_loop"(%0) ( {
+
+  // CHECK: ^bb0
+  // CHECK-SAME: tensor<i32>
+  ^bb0(%arg2: tensor<*xi32>):
+    %2 = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+
+    // CHECK: "tosa.greater_equal"
+    // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = "tosa.greater_equal"(%2, %arg2) : (tensor<i32>, tensor<*xi32>) -> tensor<*xi1>
+
+    // CHECK: "tosa.yield"
+    // CHECK-SAME: tensor<i1>
+    "tosa.yield"(%3) : (tensor<*xi1>) -> ()
+  }, {
+
+  // CHECK: ^bb0
+  // CHECK-SAME: tensor<i32>
+  ^bb0(%arg2: tensor<*xi32>):
+    %2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+
+    // CHECK: "tosa.add"
+    // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i32>
+    %3 = "tosa.add"(%arg2, %2) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
+
+    // CHECK: "tosa.yield"
+    // CHECK-SAME: tensor<i32>
+    "tosa.yield"(%3) : (tensor<*xi32>) -> ()
+
+  // CHECK: (tensor<i32>) -> tensor<i32>
+  }) : (tensor<*xi32>) -> (tensor<*xi32>)
+
+  // CHECK: tensor.cast
+  return %1 : tensor<*xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @while_test
+func @while_test(%arg0 : tensor<i32>, %arg1 : tensor<1xi32>) -> () {
+  // CHECK: "tosa.while_loop"
+  %1:2 = "tosa.while_loop"(%arg0, %arg1) ( {
+
+  // CHECK: ^bb0
+  // CHECK-SAME: tensor<i32>
+  // CHECK-SAME: tensor<?xi32>
+  ^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xi32>):
+    %2 = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+
+    // CHECK: "tosa.greater_equal"
+    // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = "tosa.greater_equal"(%2, %arg2) : (tensor<i32>, tensor<*xi32>) -> tensor<*xi1>
+    "tosa.yield"(%3) : (tensor<*xi1>) -> ()
+  }, {
+
+  // CHECK: ^bb0
+  // CHECK-SAME: tensor<i32>
+  // CHECK-SAME: tensor<?xi32>
+  ^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xi32>):
+    %2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+
+    // CHECK: "tosa.add"
+    // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i32>
+    %3 = "tosa.add"(%arg2, %2) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
+
+    // CHECK: "tosa.concat"
+    // CHECK-SAME: (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
+    %4 = "tosa.concat"(%arg3, %arg3) { axis = 0 : i64 } : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
+
+    // CHECK: "tosa.yield"
+    // CHECK-SAME: tensor<i32>
+    // CHECK-SAME: tensor<?xi32>
+    "tosa.yield"(%3, %4) : (tensor<*xi32>, tensor<*xi32>) -> ()
+
+  // CHECK: (tensor<i32>, tensor<1xi32>) -> (tensor<i32>, tensor<?xi32>)
+  }) : (tensor<i32>, tensor<1xi32>) -> (tensor<*xi32>, tensor<*xi32>)
+  return
+}