diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -18,6 +18,7 @@ #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/VectorInterfaces.h" //===----------------------------------------------------------------------===// // TOSA dialect and structs includes. diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td @@ -19,12 +19,16 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/LoopLikeInterface.td" +include "mlir/Interfaces/VectorInterfaces.td" include "mlir/Dialect/Tosa/IR/TosaInterfaces.td" include "mlir/Dialect/Tosa/IR/TosaTypesBase.td" include "mlir/Dialect/Tosa/IR/TosaOpBase.td" -def Tosa_ApplyScaleOp: Tosa_Op<"apply_scale", [Pure] # ElementwiseMappable.traits> { +def Tosa_ApplyScaleOp : + Tosa_Op<"apply_scale", + [Pure, DeclareOpInterfaceMethods<VectorUnrollOpInterface>] # + ElementwiseMappable.traits> { let summary = "Rescale scalar operator for Tosa tensor operators"; let description = [{ @@ -46,6 +50,10 @@ let results = (outs Tosa_IntLike:$output ); + + let extraClassDeclaration = [{ + std::optional<SmallVector<int64_t, 4>> getShapeForUnroll(); + }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt --- a/mlir/lib/Dialect/Tosa/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt @@ -21,6 +21,7 @@ MLIRQuantUtils MLIRSideEffectInterfaces MLIRTensorDialect + MLIRVectorInterfaces MLIRViewLikeInterface ) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- 
a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -1402,6 +1402,12 @@ return success(); } +std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() { + if (auto vt = getType().dyn_cast<VectorType>()) + return llvm::to_vector<4>(vt.getShape()); + return std::nullopt; +} + //===----------------------------------------------------------------------===// // TOSA Attribute Definitions. //===----------------------------------------------------------------------===// diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -8797,6 +8797,7 @@ ":LoopLikeInterfaceTdFiles", ":OpBaseTdFiles", ":SideEffectInterfacesTdFiles", + ":VectorInterfacesTdFiles", ], )