diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -21,6 +21,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -201,6 +202,19 @@
     return it->second;
   };
 
+  // Check whether this use case is replaceable. We define an op as
+  // being replaceable if it is used by a ReturnOp, a TosaOp, or an op with a
+  // type-inference related interface.
+  auto isReplaceableUser = [](Operation *user) -> bool {
+    // getDialect() returns null when the user's dialect is not loaded (e.g.
+    // an unregistered op); such a user is never treated as a TosaOp.
+    Dialect *dialect = user->getDialect();
+    return isa<func::ReturnOp>(user) ||
+           (dialect &&
+            dialect->getNamespace() == TosaDialect::getDialectNamespace()) ||
+           isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
+  };
+
   for (auto &block : region) {
     for (Operation &op : block) {
       if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
@@ -229,16 +243,8 @@
 
-          // Check whether this use case is replaceable. We define an op as
-          // being replaceable if it is used by a ReturnOp or a TosaOp.
-          bool replaceable = true;
-          for (auto *user : result.getUsers()) {
-            if (isa<func::ReturnOp>(user))
-              continue;
-            if (user->getDialect()->getNamespace() ==
-                TosaDialect::getDialectNamespace())
-              continue;
-
-            replaceable = false;
-          }
+          // Skip results that have a user which cannot accept a refined
+          // type; see isReplaceableUser above.
+          if (!llvm::all_of(result.getUsers(), isReplaceableUser))
+            continue;
 
           // Determine the knowledge based on the output type.
           // TODO: should also query WIP type probably
@@ -256,9 +262,6 @@
             }
           }
 
-          if (!replaceable)
-            continue;
-
           // Compute the new type based on the joined version.
           auto newKnowledge =
               ValueKnowledge::join(currentKnowledge, inferredKnowledge);
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1234,3 +1234,33 @@
 
   return
 }
+
+// -----
+
+// CHECK-LABEL: test_non_tosa_consumer_shape
+func.func @test_non_tosa_consumer_shape(%arg0: tensor<4x4xf32>) -> !shape.shape {
+  // CHECK: "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
+  %0 = "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<*xf32>
+  %1 = shape.shape_of %0 : tensor<*xf32> -> !shape.shape
+  return %1 : !shape.shape
+}
+
+// -----
+
+// CHECK-LABEL: test_non_tosa_consumer_shape2
+func.func @test_non_tosa_consumer_shape2(%arg0: tensor<4x4xf32>) -> tensor<?xindex> {
+  // CHECK: "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
+  %0 = "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<*xf32>
+  %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+  return %1 : tensor<?xindex>
+}
+
+// -----
+
+// CHECK-LABEL: test_non_tosa_consumer_extract
+func.func @test_non_tosa_consumer_extract(%arg0: tensor<4x4xf32>, %arg1: index) -> f32 {
+  // CHECK: "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
+  %0 = "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<?x?xf32>
+  %1 = tensor.extract %0[%arg1, %arg1] : tensor<?x?xf32>
+  return %1 : f32
+}