diff --git a/flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp b/flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp --- a/flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp +++ b/flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp @@ -25,29 +25,26 @@ // Codegen rewrite: rewriting of subgraphs of ops //===----------------------------------------------------------------------===// -using namespace fir; -using namespace mlir; - #define DEBUG_TYPE "flang-codegen-rewrite" static void populateShape(llvm::SmallVectorImpl &vec, - ShapeOp shape) { + fir::ShapeOp shape) { vec.append(shape.getExtents().begin(), shape.getExtents().end()); } // Operands of fir.shape_shift split into two vectors. static void populateShapeAndShift(llvm::SmallVectorImpl &shapeVec, llvm::SmallVectorImpl &shiftVec, - ShapeShiftOp shift) { - auto endIter = shift.getPairs().end(); - for (auto i = shift.getPairs().begin(); i != endIter;) { + fir::ShapeShiftOp shift) { + for (auto i = shift.getPairs().begin(), endIter = shift.getPairs().end(); + i != endIter;) { shiftVec.push_back(*i++); shapeVec.push_back(*i++); } } static void populateShift(llvm::SmallVectorImpl &vec, - ShiftOp shift) { + fir::ShiftOp shift) { vec.append(shift.getOrigins().begin(), shift.getOrigins().end()); } @@ -72,27 +69,26 @@ /// (!fir.ref>, index, index, index, index, index) -> /// !fir.box> /// ``` -class EmboxConversion : public mlir::OpRewritePattern { +class EmboxConversion : public mlir::OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; mlir::LogicalResult - matchAndRewrite(EmboxOp embox, + matchAndRewrite(fir::EmboxOp embox, mlir::PatternRewriter &rewriter) const override { - auto shapeVal = embox.getShape(); // If the embox does not include a shape, then do not convert it - if (shapeVal) + if (auto shapeVal = embox.getShape()) return rewriteDynamicShape(embox, rewriter, shapeVal); - if (auto boxTy = embox.getType().dyn_cast()) - if (auto seqTy = boxTy.getEleTy().dyn_cast()) + if (auto boxTy = embox.getType().dyn_cast()) + if (auto seqTy = boxTy.getEleTy().dyn_cast()) if (seqTy.hasConstantShape()) return rewriteStaticShape(embox, rewriter, seqTy); return mlir::failure(); } - mlir::LogicalResult rewriteStaticShape(EmboxOp embox, + mlir::LogicalResult rewriteStaticShape(fir::EmboxOp embox, mlir::PatternRewriter &rewriter, - SequenceType seqTy) const { + fir::SequenceType seqTy) const { auto loc = embox.getLoc(); llvm::SmallVector shapeOpers; auto idxTy = rewriter.getIndexType(); @@ -101,7 +97,7 @@ auto extVal = rewriter.create(loc, idxTy, iAttr); shapeOpers.push_back(extVal); } - auto xbox = rewriter.create( + auto xbox = rewriter.create( loc, embox.getType(), embox.getMemref(), shapeOpers, llvm::None, llvm::None, llvm::None, llvm::None, embox.getTypeparams()); LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); @@ -109,17 +105,17 @@ return mlir::success(); } - mlir::LogicalResult rewriteDynamicShape(EmboxOp embox, + mlir::LogicalResult rewriteDynamicShape(fir::EmboxOp embox, mlir::PatternRewriter &rewriter, mlir::Value shapeVal) const { auto loc = embox.getLoc(); - auto shapeOp = dyn_cast(shapeVal.getDefiningOp()); llvm::SmallVector shapeOpers; llvm::SmallVector shiftOpers; - if (shapeOp) { + if (auto shapeOp = mlir::dyn_cast(shapeVal.getDefiningOp())) { populateShape(shapeOpers, shapeOp); } else { - auto shiftOp = dyn_cast(shapeVal.getDefiningOp()); + auto shiftOp = + mlir::dyn_cast(shapeVal.getDefiningOp()); assert(shiftOp && "shape is neither fir.shape nor fir.shape_shift"); populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); } @@ -127,7 +123,8 @@ llvm::SmallVector subcompOpers; llvm::SmallVector substrOpers; if (auto s = embox.getSlice()) - if (auto sliceOp = dyn_cast_or_null(s.getDefiningOp())) { + if (auto sliceOp = + mlir::dyn_cast_or_null(s.getDefiningOp())) { sliceOpers.assign(sliceOp.getTriples().begin(), sliceOp.getTriples().end()); subcompOpers.assign(sliceOp.getFields().begin(), @@ -135,7 +132,7 @@ substrOpers.assign(sliceOp.getSubstr().begin(), sliceOp.getSubstr().end()); } - auto xbox = rewriter.create( + auto xbox = rewriter.create( loc, embox.getType(), embox.getMemref(), shapeOpers, shiftOpers, sliceOpers, subcompOpers, substrOpers, embox.getTypeparams()); LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); @@ -156,22 +153,24 @@ /// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box>, /// index, index) -> !fir.box> /// ``` -class ReboxConversion : public mlir::OpRewritePattern { +class ReboxConversion : public mlir::OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; mlir::LogicalResult - matchAndRewrite(ReboxOp rebox, + matchAndRewrite(fir::ReboxOp rebox, mlir::PatternRewriter &rewriter) const override { auto loc = rebox.getLoc(); llvm::SmallVector shapeOpers; llvm::SmallVector shiftOpers; if (auto shapeVal = rebox.getShape()) { - if (auto shapeOp = dyn_cast(shapeVal.getDefiningOp())) + if (auto shapeOp = mlir::dyn_cast(shapeVal.getDefiningOp())) populateShape(shapeOpers, shapeOp); - else if (auto shiftOp = dyn_cast(shapeVal.getDefiningOp())) + else if (auto shiftOp = + mlir::dyn_cast(shapeVal.getDefiningOp())) populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); - else if (auto shiftOp = dyn_cast(shapeVal.getDefiningOp())) + else if (auto shiftOp = + mlir::dyn_cast(shapeVal.getDefiningOp())) populateShift(shiftOpers, shiftOp); else return mlir::failure(); @@ -180,7 +179,8 @@ llvm::SmallVector subcompOpers; llvm::SmallVector substrOpers; if (auto s = rebox.getSlice()) - if (auto sliceOp = dyn_cast_or_null(s.getDefiningOp())) { + if (auto sliceOp = + mlir::dyn_cast_or_null(s.getDefiningOp())) { sliceOpers.append(sliceOp.getTriples().begin(), sliceOp.getTriples().end()); subcompOpers.append(sliceOp.getFields().begin(), @@ -189,7 +189,7 @@ sliceOp.getSubstr().end()); } - auto xRebox = rewriter.create( + auto xRebox = rewriter.create( loc, rebox.getType(), rebox.getBox(), shapeOpers, shiftOpers, sliceOpers, subcompOpers, substrOpers); LLVM_DEBUG(llvm::dbgs() @@ -212,22 +212,24 @@ /// (!fir.ref>, index, index, index, index, index, index) -> /// !fir.ref /// ``` -class ArrayCoorConversion : public mlir::OpRewritePattern { +class ArrayCoorConversion : public mlir::OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; mlir::LogicalResult - matchAndRewrite(ArrayCoorOp arrCoor, + matchAndRewrite(fir::ArrayCoorOp arrCoor, mlir::PatternRewriter &rewriter) const override { auto loc = arrCoor.getLoc(); llvm::SmallVector shapeOpers; llvm::SmallVector shiftOpers; if (auto shapeVal = arrCoor.getShape()) { - if (auto shapeOp = dyn_cast(shapeVal.getDefiningOp())) + if (auto shapeOp = mlir::dyn_cast(shapeVal.getDefiningOp())) populateShape(shapeOpers, shapeOp); - else if (auto shiftOp = dyn_cast(shapeVal.getDefiningOp())) + else if (auto shiftOp = + mlir::dyn_cast(shapeVal.getDefiningOp())) populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); - else if (auto shiftOp = dyn_cast(shapeVal.getDefiningOp())) + else if (auto shiftOp = + mlir::dyn_cast(shapeVal.getDefiningOp())) populateShift(shiftOpers, shiftOp); else return mlir::failure(); @@ -235,7 +237,8 @@ llvm::SmallVector sliceOpers; llvm::SmallVector subcompOpers; if (auto s = arrCoor.getSlice()) - if (auto sliceOp = dyn_cast_or_null(s.getDefiningOp())) { + if (auto sliceOp = + mlir::dyn_cast_or_null(s.getDefiningOp())) { sliceOpers.append(sliceOp.getTriples().begin(), sliceOp.getTriples().end()); subcompOpers.append(sliceOp.getFields().begin(), @@ -244,7 +247,7 @@ "Don't allow substring operations on array_coor. This " "restriction may be lifted in the future."); } - auto xArrCoor = rewriter.create( + auto xArrCoor = rewriter.create( loc, arrCoor.getType(), arrCoor.getMemref(), shapeOpers, shiftOpers, sliceOpers, subcompOpers, arrCoor.getIndices(), arrCoor.getTypeparams()); @@ -255,20 +258,22 @@ } }; -class CodeGenRewrite : public CodeGenRewriteBase { +class CodeGenRewrite : public fir::CodeGenRewriteBase { public: void runOnOperation() override final { auto op = getOperation(); auto &context = getContext(); mlir::OpBuilder rewriter(&context); mlir::ConversionTarget target(context); - target.addLegalDialect(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addDynamicallyLegalOp([](EmboxOp embox) { - return !(embox.getShape() || - embox.getType().cast().getEleTy().isa()); + target.addLegalDialect(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addDynamicallyLegalOp([](fir::EmboxOp embox) { + return !(embox.getShape() || embox.getType() + .cast() + .getEleTy() + .isa()); }); mlir::RewritePatternSet patterns(&context); patterns.insert(