diff --git a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
--- a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
+++ b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
@@ -172,39 +172,37 @@
 ShapeOfOpConverter::matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                                     ConversionPatternRewriter &rewriter) const {
   ShapeOfOp::Adaptor transformed(operands);
-  auto tensorVal = transformed.arg();
-  auto tensorTy = tensorVal.getType();
+  Value arg = transformed.arg();
+  Type argTy = arg.getType();
 
   // For ranked tensors `shape_of` lowers to `std` and the pattern can be
   // found in the corresponding pass.
-  if (tensorTy.isa<RankedTensorType>())
+  if (argTy.isa<RankedTensorType>())
     return failure();
 
   // Allocate stack memory.
   auto loc = op.getLoc();
-  auto rankVal = rewriter.create<mlir::RankOp>(loc, tensorVal);
-  auto i64Ty = rewriter.getI64Type();
-  auto memTy = MemRefType::get({ShapedType::kDynamicSize}, i64Ty);
-  auto memVal = rewriter.create<AllocaOp>(loc, memTy, ValueRange({rankVal}));
+  Value rank = rewriter.create<mlir::RankOp>(loc, arg);
+  Type i64Ty = rewriter.getI64Type();
+  Type memTy = MemRefType::get({ShapedType::kDynamicSize}, i64Ty);
+  Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{rank});
 
   // Copy shape extents to stack-allocated memory.
-  auto zeroVal = rewriter.create<ConstantIndexOp>(loc, 0);
-  auto oneVal = rewriter.create<ConstantIndexOp>(loc, 1);
+  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
   rewriter.create<scf::ForOp>(
-      loc, zeroVal, rankVal, oneVal, llvm::None,
-      [&](OpBuilder &b, Location loc, Value iVal, ValueRange args) {
-        auto dimVal = rewriter.create<DimOp>(loc, tensorVal, iVal);
-        auto dimIntVal = rewriter.create<IndexCastOp>(loc, dimVal, i64Ty);
-        rewriter.create<StoreOp>(loc, dimIntVal, memVal, ValueRange{iVal});
+      loc, zero, rank, one, llvm::None,
+      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
+        Value dim = rewriter.create<DimOp>(loc, arg, iv);
+        Value dimInt = rewriter.create<IndexCastOp>(loc, dim, i64Ty);
+        rewriter.create<StoreOp>(loc, dimInt, mem, ValueRange{iv});
         rewriter.create<scf::YieldOp>(loc);
       });
 
   // Load extents to tensor value.
-  auto shapeIntVal = rewriter.create<TensorLoadOp>(loc, memVal);
-  auto indexTy = rewriter.getIndexType();
-  auto shapeTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
-  rewriter.replaceOpWithNewOp<IndexCastOp>(op.getOperation(), shapeIntVal,
-                                           shapeTy);
+  Value extentTensorInt = rewriter.create<TensorLoadOp>(loc, mem);
+  rewriter.replaceOpWithNewOp<IndexCastOp>(op.getOperation(), extentTensorInt,
+                                           op.getType());
 
   return success();
 }