diff --git a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
--- a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
+++ b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
@@ -186,8 +186,8 @@
   // Allocate stack memory.
   auto loc = op.getLoc();
   Value rank = rewriter.create<mlir::RankOp>(loc, arg);
-  Type i64Ty = rewriter.getI64Type();
-  Type memTy = MemRefType::get({ShapedType::kDynamicSize}, i64Ty);
+  Type indexTy = rewriter.getIndexType();
+  Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
   Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{rank});
 
   // Copy shape extents to stack-allocated memory.
@@ -197,15 +197,12 @@
       loc, zero, rank, one, llvm::None,
       [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
         Value dim = rewriter.create<DimOp>(loc, arg, iv);
-        Value dimInt = rewriter.create<IndexCastOp>(loc, dim, i64Ty);
-        rewriter.create<StoreOp>(loc, dimInt, mem, ValueRange{iv});
+        rewriter.create<StoreOp>(loc, dim, mem, ValueRange{iv});
         rewriter.create<scf::YieldOp>(loc);
       });
 
   // Load extents to tensor value.
-  Value extentTensorInt = rewriter.create<TensorLoadOp>(loc, mem);
-  rewriter.replaceOpWithNewOp<IndexCastOp>(op.getOperation(), extentTensorInt,
-                                           op.getType());
+  rewriter.replaceOpWithNewOp<TensorLoadOp>(op.getOperation(), mem);
 
   return success();
 }
diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
--- a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
+++ b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
@@ -40,16 +40,14 @@
 // CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
 func @shape_of_unranked(%arg : tensor<*xf32>) {
   // CHECK: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
-  // CHECK: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xi64>
+  // CHECK: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xindex>
   // CHECK: %[[C0:.*]] = constant 0 : index
   // CHECK: %[[C1:.*]] = constant 1 : index
   // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
   // CHECK: %[[DIM:.*]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
-  // CHECK: %[[DIM_INT:.*]] = index_cast %[[DIM]] : index to i64
-  // CHECK: store %[[DIM_INT]], %[[SHAPE_MEM]][%[[I]]] : memref<?xi64>
+  // CHECK: store %[[DIM]], %[[SHAPE_MEM]][%[[I]]] : memref<?xindex>
   // CHECK: }
-  // CHECK: %[[SHAPE_INT:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xi64>
-  // CHECK: %[[SHAPE:.*]] = index_cast %[[SHAPE_INT]] : tensor<?xi64> to tensor<?xindex>
+  // CHECK: %[[SHAPE:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xindex>
   %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
   return
 }
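
For reference, this is roughly the IR the pattern now emits for an unranked
operand. It is a minimal sketch reconstructed from the CHECK lines in the test
above; the SSA value names are illustrative only, and the index_cast ops from
the previous lowering are gone because the extents are stored as index values
directly.

  %rank = rank %arg : tensor<*xf32>
  %mem = alloca(%rank) : memref<?xindex>
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  scf.for %i = %c0 to %rank step %c1 {
    %dim = dim %arg, %i : tensor<*xf32>
    store %dim, %mem[%i] : memref<?xindex>
  }
  %shape = tensor_load %mem : memref<?xindex>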