diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -834,7 +834,15 @@
   bodyRegion->push_back(new Block);
   Block &bodyBlock = bodyRegion->front();
   bodyBlock.addArgument(builder.getIndexType());
-  bodyBlock.addArgument(SizeType::get(builder.getContext()));
+
+  // Type the per-dimension block argument after the shape operand: the
+  // tensor's element type for tensor shapes, !shape.size otherwise.
+  Type elementType;
+  if (auto tensorType = shape.getType().dyn_cast<TensorType>())
+    elementType = tensorType.getElementType();
+  else
+    elementType = SizeType::get(builder.getContext());
+  bodyBlock.addArgument(elementType);
 
   for (Type initValType : initVals.getTypes()) {
     bodyBlock.addArgument(initValType);
diff --git a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
--- a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
@@ -9,6 +9,7 @@
 #include "PassDetail.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
@@ -32,14 +33,21 @@
 NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
                                         PatternRewriter &rewriter) const {
   auto loc = op.getLoc();
-  Value init = rewriter.create<ConstSizeOp>(loc, rewriter.getIndexAttr(1));
+  Type valueType = op.getResult().getType();
+  // Materialize the initial product 1 in the result type. The dialect hook
+  // returns null when it cannot build a constant of that type.
+  Operation *initOp = op.getDialect()->materializeConstant(
+      rewriter, rewriter.getIndexAttr(1), valueType, loc);
+  if (!initOp)
+    return failure();
+  Value init = initOp->getResult(0);
   ReduceOp reduce = rewriter.create<ReduceOp>(loc, op.shape(), init);
 
   // Generate reduce operator.
   Block *body = reduce.getBody();
   OpBuilder b = OpBuilder::atBlockEnd(body);
-  Value product = b.create<MulOp>(loc, b.getType<SizeType>(),
-                                  body->getArgument(1), body->getArgument(2));
+  Value product = b.create<MulOp>(loc, valueType, body->getArgument(1),
+                                  body->getArgument(2));
   b.create<YieldOp>(loc, product);
 
   rewriter.replaceOp(op, reduce.result());
@@ -60,7 +68,8 @@
   populateShapeRewritePatterns(&ctx, patterns);
 
   ConversionTarget target(getContext());
-  target.addLegalDialect<ShapeDialect>();
+  // Index-typed constants are materialized as std.constant ops.
+  target.addLegalDialect<ShapeDialect, StandardOpsDialect>();
   target.addIllegalOp<NumElementsOp>();
   if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
     signalPassFailure();
diff --git a/mlir/test/Dialect/Shape/shape-to-shape.mlir b/mlir/test/Dialect/Shape/shape-to-shape.mlir
--- a/mlir/test/Dialect/Shape/shape-to-shape.mlir
+++ b/mlir/test/Dialect/Shape/shape-to-shape.mlir
@@ -14,3 +14,18 @@
 // CHECK: }
 // CHECK: return [[NUM_ELEMENTS]] : !shape.size
 
+// -----
+
+// CHECK-LABEL: func @num_elements_to_reduce_on_index
+// CHECK-SAME: ([[ARG:%.*]]: tensor<?xindex>) -> index
+func @num_elements_to_reduce_on_index(%shape : tensor<?xindex>) -> index {
+  %num_elements = shape.num_elements %shape : tensor<?xindex> -> index
+  return %num_elements : index
+}
+// CHECK: [[C1:%.*]] = constant 1 : index
+// CHECK: [[NUM_ELEMENTS:%.*]] = shape.reduce([[ARG]], [[C1]]) : tensor<?xindex> -> index
+// CHECK: ^bb0({{.*}}: index, [[DIM:%.*]]: index, [[ACC:%.*]]: index
+// CHECK: [[NEW_ACC:%.*]] = shape.mul [[DIM]], [[ACC]]
+// CHECK: shape.yield [[NEW_ACC]] : index
+// CHECK: }
+// CHECK: return [[NUM_ELEMENTS]] : index