diff --git a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp --- a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp +++ b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp @@ -38,7 +38,7 @@ precision(precision) {} LogicalResult matchAndRewrite(nvgpu::MmaSyncOp op, - PatternRewriter &rewrite) const override { + PatternRewriter &rewriter) const override { Location location = op->getLoc(); if (op->hasAttr(op.getTf32EnabledAttrName()) || @@ -53,8 +53,10 @@ return emitError(location, "TF32x3 is not supported at the moment " "for nvgpu.mma.sync on f32 datatype"); - if (precision == MmaSyncF32Lowering::TF32) - op.setTf32EnabledAttr(rewrite.getUnitAttr()); + if (precision == MmaSyncF32Lowering::TF32) { + rewriter.updateRootInPlace( + op, [&]() { op.setTf32EnabledAttr(rewriter.getUnitAttr()); }); + } return success(); }