diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -326,6 +326,7 @@
   }];
 
   let verifier = [{ return ::verify(*this); }];
+  let hasFolder = 1;
 }
 
 def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -701,6 +701,21 @@
   return success();
 }
 
+// Folds `shape.mul` when both operand attributes are IntegerAttrs: the
+// product of their values is returned as an index-typed IntegerAttr.
+// Returns null (no fold) if either operand attribute is absent or not an
+// IntegerAttr.
+OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
+  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
+  if (!lhs)
+    return nullptr;
+  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
+  if (!rhs)
+    return nullptr;
+  APInt folded = lhs.getValue() * rhs.getValue();
+  Type indexTy = IndexType::get(getContext());
+  return IntegerAttr::get(indexTy, folded);
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeOfOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -734,3 +734,43 @@
   %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
   return %result : i1
 }
+
+// -----
+
+// Fold `mul` for constant sizes.
+// CHECK-LABEL: @fold_mul_size
+func @fold_mul_size() -> !shape.size {
+  // CHECK: %[[RESULT:.*]] = shape.const_size 6
+  // CHECK: return %[[RESULT]] : !shape.size
+  %c2 = shape.const_size 2
+  %c3 = shape.const_size 3
+  %result = shape.mul %c2, %c3 : !shape.size, !shape.size -> !shape.size
+  return %result : !shape.size
+}
+
+// -----
+
+// Fold `mul` for constant indices.
+// CHECK-LABEL: @fold_mul_index
+func @fold_mul_index() -> index {
+  // CHECK: %[[RESULT:.*]] = constant 6 : index
+  // CHECK: return %[[RESULT]] : index
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
+  %result = shape.mul %c2, %c3 : index, index -> index
+  return %result : index
+}
+
+// -----
+
+// Fold `mul` for mixed constants.
+// CHECK-LABEL: @fold_mul_mixed
+func @fold_mul_mixed() -> !shape.size {
+  // CHECK: %[[RESULT:.*]] = shape.const_size 6
+  // CHECK: return %[[RESULT]] : !shape.size
+  %c2 = shape.const_size 2
+  %c3 = constant 3 : index
+  %result = shape.mul %c2, %c3 : !shape.size, index -> !shape.size
+  return %result : !shape.size
+}
+