diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -203,8 +203,7 @@
   /// detensored, then:
   /// - opsToDetensor should be = {linalg.generic{add}}.
   /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}.
-  virtual void compute(FunctionOpInterface func,
-                       DetensorizeTypeConverter typeConverter,
+  virtual void compute(Operation *op, DetensorizeTypeConverter typeConverter,
                        DenseSet<Operation *> &opsToDetensor,
                        DenseSet<BlockArgument> &blockArgsToDetensor) = 0;
 
@@ -255,17 +254,16 @@
 /// AND can be detensored.
 class ControlFlowDetectionModel : public CostModel {
 public:
-  void compute(FunctionOpInterface func,
-               DetensorizeTypeConverter typeConverter,
+  void compute(Operation *op, DetensorizeTypeConverter typeConverter,
                DenseSet<Operation *> &opsToDetensor,
                DenseSet<BlockArgument> &blockArgsToDetensor) override {
     SmallVector<Value> workList;
 
-    func->walk([&](cf::CondBranchOp condBr) {
+    op->walk([&](cf::CondBranchOp condBr) {
       llvm::append_range(workList, condBr.getOperands());
     });
 
-    func->walk([&](cf::BranchOp br) {
+    op->walk([&](cf::BranchOp br) {
       llvm::append_range(workList, br.getOperands());
     });
 
@@ -451,18 +449,22 @@
 /// Detensorize everything that can detensored.
 class AggressiveDetensoringModel : public CostModel {
 public:
-  void compute(FunctionOpInterface func,
-               DetensorizeTypeConverter typeConverter,
+  void compute(Operation *op, DetensorizeTypeConverter typeConverter,
                DenseSet<Operation *> &opsToDetensor,
                DenseSet<BlockArgument> &blockArgsToDetensor) override {
-    func->walk([&](GenericOp genericOp) {
+    op->walk([&](GenericOp genericOp) {
       if (shouldBeDetensored(genericOp, typeConverter))
         opsToDetensor.insert(genericOp);
     });
 
-    for (Block &block : llvm::drop_begin(func.getFunctionBody(), 1))
-      for (BlockArgument blockArgument : block.getArguments())
-        blockArgsToDetensor.insert(blockArgument);
+    // Visit `op` itself and every function nested in it, so that scheduling
+    // the pass on a module (rather than on each function) still collects the
+    // non-entry block arguments of all functions.
+    op->walk([&](FunctionOpInterface func) {
+      for (Block &block : llvm::drop_begin(func.getFunctionBody(), 1))
+        for (BlockArgument blockArgument : block.getArguments())
+          blockArgsToDetensor.insert(blockArgument);
+    });
   }
 };
 
@@ -474,16 +476,14 @@
     DenseSet<Operation *> opsToDetensor;
     DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
     DenseSet<BlockArgument> blockArgsToDetensor;
-    FunctionOpInterface funcOp = cast<FunctionOpInterface>(getOperation());
+    Operation *op = getOperation();
 
     if (aggressiveMode.getValue()) {
       AggressiveDetensoringModel costModel;
-      costModel.compute(funcOp, typeConverter, opsToDetensor,
-                        blockArgsToDetensor);
+      costModel.compute(op, typeConverter, opsToDetensor, blockArgsToDetensor);
     } else {
       ControlFlowDetectionModel costModel;
-      costModel.compute(funcOp, typeConverter, opsToDetensor,
-                        blockArgsToDetensor);
+      costModel.compute(op, typeConverter, opsToDetensor, blockArgsToDetensor);
     }
 
     detensorableBranchOps =
diff --git a/mlir/test/Dialect/Linalg/detensorize_0d.mlir b/mlir/test/Dialect/Linalg/detensorize_0d.mlir
--- a/mlir/test/Dialect/Linalg/detensorize_0d.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_0d.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline="builtin.module(func.func(linalg-detensorize{aggressive-mode}))" | FileCheck %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize="aggressive-mode" | FileCheck %s
 
 #map = affine_map<() -> ()>
 
@@ -100,3 +101,9 @@
 // CHECK: %[[detensored_res:.*]] = "foreign.do_something"(%[[arg1_val]], %[[arg2_val]])
 // CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]]
 // CHECK: return %[[new_tensor_res]]
+
+// CHECK-LABEL: func @regression()
+func.func @regression() {
+  %0 = tensor.empty() : tensor<4xf32>
+  return
+}