diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -88,6 +88,11 @@
   static constexpr int64_t kDynamicStrideOrOffset =
       std::numeric_limits<int64_t>::min();
 
+  /// Return clone of this type with new shape and element type.
+  ShapedType clone(ArrayRef<int64_t> shape, Type elementType);
+  ShapedType clone(ArrayRef<int64_t> shape);
+  ShapedType clone(Type elementType);
+
   /// Return the element type.
   Type getElementType() const;
 
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -197,6 +197,58 @@
 constexpr int64_t ShapedType::kDynamicSize;
 constexpr int64_t ShapedType::kDynamicStrideOrOffset;
 
+ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) {
+  if (auto other = dyn_cast<MemRefType>()) {
+    MemRefType::Builder b(other);
+    b.setShape(shape);
+    b.setElementType(elementType);
+    return b;
+  }
+
+  if (isa<TensorType>())
+    return RankedTensorType::get(shape, elementType);
+
+  if (isa<VectorType>())
+    return VectorType::get(shape, elementType);
+
+  llvm_unreachable("Unhandled ShapedType clone case");
+}
+
+ShapedType ShapedType::clone(ArrayRef<int64_t> shape) {
+  if (auto other = dyn_cast<MemRefType>()) {
+    MemRefType::Builder b(other);
+    b.setShape(shape);
+    return b;
+  }
+
+  if (isa<TensorType>())
+    return RankedTensorType::get(shape, getElementType());
+
+  if (isa<VectorType>())
+    return VectorType::get(shape, getElementType());
+
+  llvm_unreachable("Unhandled ShapedType clone case");
+}
+
+ShapedType ShapedType::clone(Type elementType) {
+  if (auto other = dyn_cast<MemRefType>()) {
+    MemRefType::Builder b(other);
+    b.setElementType(elementType);
+    return b;
+  }
+
+  if (isa<TensorType>()) {
+    if (hasRank())
+      return RankedTensorType::get(getShape(), elementType);
+    return UnrankedTensorType::get(elementType);
+  }
+
+  if (isa<VectorType>()) {
+    return VectorType::get(getShape(), elementType);
+  }
+  llvm_unreachable("Unhandled ShapedType clone case");
+}
+
 Type ShapedType::getElementType() const {
   return static_cast<ShapedTypeStorage *>(impl)->elementType;
 }
diff --git a/mlir/test/IR/update-shape.mlir b/mlir/test/IR/update-shape.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/IR/update-shape.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s -test-update-element-or-type --mlir-print-local-scope | FileCheck %s
+
+func @succeededSameOperandsElementType() {
+  // CHECK: tensor<121xf32>
+  "test.update_variadic_return"() {type_attr = f32} : () -> (tensor<121xi32>)
+  // CHECK: vector<16xf32>
+  "test.update_variadic_return"() {shape_attr = dense<16> : vector<1xi32>} : () -> (vector<121xf32>)
+  // CHECK: memref<64x256xi32, affine_map<(d0, d1) -> (d1, d0)>>
+  "test.update_variadic_return"() {type_attr = i32} : () -> (memref<64x256xf32, affine_map<(d0, d1) -> (d1, d0)>>)
+  // CHECK: memref<10x20xi32, affine_map<(d0, d1) -> (d0, d1 * 2)>>
+  "test.update_variadic_return"() {shape_attr = dense<[10, 20]> : tensor<2xi32>, type_attr = i32} : () -> (memref<64x256xf32, affine_map<(d0, d1) -> (d0, 2 * d1)>>)
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -90,6 +90,14 @@
   );
 }
 
+def UpdatedVariadicReturnOp : TEST_Op<"update_variadic_return"> {
+  let arguments = (ins
+    OptionalAttr<TypeAttr>:$type_attr,
+    OptionalAttr<I32ElementsAttr>:$shape_attr
+  );
+  let results = (outs Variadic<AnyType>);
+}
+
 def TEST_TestType : DialectType<Test_Dialect,
     CPred<"$_self.isa<::mlir::test::TestType>()">, "test">,
   BuildableType<"$_builder.getType<::mlir::test::TestType>()">;
diff --git a/mlir/test/lib/Transforms/TestUpdateShapedTypes.cpp b/mlir/test/lib/Transforms/TestUpdateShapedTypes.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestUpdateShapedTypes.cpp
@@ -0,0 +1,58 @@
+//===- TestUpdateShapedTypes.cpp - Pass to test updating element & shape --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Pass that changes element or shape of shaped types.
+struct TestChangeElementOrShape
+    : public PassWrapper<TestChangeElementOrShape, OperationPass<ModuleOp>> {
+
+  void runOnOperation() override {
+    getOperation().walk([&](test::UpdatedVariadicReturnOp op) {
+      Optional<Type> elementType = op.type_attr();
+      Optional<DenseIntElementsAttr> shapeAttr = op.shape_attr();
+      if (!elementType && !shapeAttr)
+        return;
+
+      for (auto result : op.getOperation()->getOpResults()) {
+        SmallVector<int64_t, 4> shape;
+        if (shapeAttr) {
+          shape.reserve(shapeAttr->size());
+          for (APInt v : shapeAttr->getIntValues())
+            shape.push_back(v.getSExtValue());
+        }
+
+        if (elementType && shapeAttr) {
+          result.setType(
+              result.getType().cast<ShapedType>().clone(shape, *elementType));
+        } else if (elementType) {
+          result.setType(
+              result.getType().cast<ShapedType>().clone(*elementType));
+        } else if (shapeAttr) {
+          result.setType(result.getType().cast<ShapedType>().clone(shape));
+        }
+      }
+    });
+  }
+};
+
+} // end anonymous namespace
+
+namespace mlir {
+namespace test {
+void registerTestChangeElementOrShape() {
+  PassRegistration<TestChangeElementOrShape> pass(
+      "test-update-element-or-type", "Changes ShapedType's element or shape");
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -62,6 +62,7 @@
 void registerTestAffineLoopParametricTilingPass();
 void registerTestAliasAnalysisPass();
 void registerTestCallGraphPass();
+void registerTestChangeElementOrShape();
 void registerTestConstantFold();
 void registerTestConvVectorization();
 void registerTestConvertGPUKernelToCubinPass();
@@ -132,6 +133,7 @@
   test::registerTestAffineLoopParametricTilingPass();
   test::registerTestAliasAnalysisPass();
   test::registerTestCallGraphPass();
+  test::registerTestChangeElementOrShape();
   test::registerTestConstantFold();
 #if MLIR_CUDA_CONVERSIONS_ENABLED
   test::registerTestConvertGPUKernelToCubinPass();