diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -38,6 +38,7 @@
   }];
 
   let cppNamespace = "mlir::tosa";
+  let hasConstantMaterializer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1512,12 +1512,13 @@
   }];
 
   let arguments = (ins
-    AnyAttr:$value
+    ElementsAttr:$value
   );
 
   let results = (outs
     Tosa_TensorUpto4D:$output
   );
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -89,6 +89,26 @@
   addInterfaces<TosaInlinerInterface>();
 }
 
+Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
+                                            Type type, Location loc) {
+  // TOSA constants only accept ElementsAttr values, unlike the standard
+  // dialect's constant op, which accepts any attribute. Returning nullptr
+  // reports that no constant could be materialized for `value`.
+  if (auto elementsAttr = value.dyn_cast<ElementsAttr>())
+    return builder.create<tosa::ConstOp>(loc, type, elementsAttr);
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// Operator Folders.
+//===----------------------------------------------------------------------===//
+
+// tosa.const has no operands and folds to its own `value` attribute.
+OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.empty() && "constant has no operands");
+  return valueAttr();
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Verifiers.
//===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Tosa/constant_folding.mlir @@ -0,0 +1,8 @@ +// RUN: mlir-opt --test-constant-fold %s | FileCheck %s + +// CHECK-LABEL: func @test_const +func @test_const(%arg0 : index) -> tensor<4xi32> { + // CHECK: "tosa.const" + %0 = "tosa.const"() {value = dense<[3, 0, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> + return %0 : tensor<4xi32> +}