diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp --- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp +++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp @@ -177,7 +177,7 @@ private: LogicalResult - matchAndRewrite(Operation *op, ArrayRef<Value> operands, + matchAndRewrite(gpu::AllocOp allocOp, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override; }; @@ -191,7 +191,7 @@ private: LogicalResult - matchAndRewrite(Operation *op, ArrayRef<Value> operands, + matchAndRewrite(gpu::DeallocOp deallocOp, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override; }; @@ -343,18 +343,16 @@ } LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite( - Operation *op, ArrayRef<Value> operands, + gpu::AllocOp allocOp, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const { - auto allocOp = cast<gpu::AllocOp>(op); MemRefType memRefType = allocOp.getType(); - if (failed(areAllLLVMTypes(op, operands, rewriter)) || + if (failed(areAllLLVMTypes(allocOp, operands, rewriter)) || !isSupportedMemRefType(memRefType) || - failed( - isAsyncWithOneDependency(rewriter, cast<gpu::AsyncOpInterface>(op)))) + failed(isAsyncWithOneDependency(rewriter, allocOp))) return failure(); - auto loc = op->getLoc(); + auto loc = allocOp.getLoc(); // Get shape of the memref as values: static sizes are constant // values and dynamic sizes are passed to 'alloc' as operands. @@ -367,7 +365,8 @@ // Allocate the underlying buffer and store a pointer to it in the MemRef // descriptor.
Type elementPtrType = this->getElementPtrType(memRefType); - auto adaptor = gpu::AllocOpAdaptor(operands, op->getAttrDictionary()); + auto adaptor = gpu::AllocOpAdaptor( + operands, allocOp.getOperation()->getAttrDictionary()); auto stream = adaptor.asyncDependencies().front(); Value allocatedPtr = allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult(0); @@ -381,29 +380,29 @@ auto memRefDescriptor = this->createMemRefDescriptor( loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter); - rewriter.replaceOp(op, {memRefDescriptor, stream}); + rewriter.replaceOp(allocOp, {memRefDescriptor, stream}); return success(); } LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite( - Operation *op, ArrayRef<Value> operands, + gpu::DeallocOp deallocOp, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const { - if (failed(areAllLLVMTypes(op, operands, rewriter)) || - failed( - isAsyncWithOneDependency(rewriter, cast<gpu::AsyncOpInterface>(op)))) + if (failed(areAllLLVMTypes(deallocOp, operands, rewriter)) || + failed(isAsyncWithOneDependency(rewriter, deallocOp))) return failure(); - Location loc = op->getLoc(); + Location loc = deallocOp.getLoc(); - auto adaptor = gpu::DeallocOpAdaptor(operands, op->getAttrDictionary()); + auto adaptor = gpu::DeallocOpAdaptor( + operands, deallocOp.getOperation()->getAttrDictionary()); Value pointer = MemRefDescriptor(adaptor.memref()).allocatedPtr(rewriter, loc); auto casted = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pointer); Value stream = adaptor.asyncDependencies().front(); deallocCallBuilder.create(loc, rewriter, {casted, stream}); - rewriter.replaceOp(op, {stream}); + rewriter.replaceOp(deallocOp, {stream}); return success(); }