diff --git a/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
--- a/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
@@ -83,6 +83,7 @@
       (`sgprOffset` $sgprOffset^)? `:`
       type($memref) `,` type($indices) `->` type($value)
   }];
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
@@ -124,6 +125,7 @@
       (`sgprOffset` $sgprOffset^)? `:`
       type($value) `->` type($memref) `,` type($indices)
   }];
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
@@ -162,6 +164,7 @@
       (`sgprOffset` $sgprOffset^)? `:`
       type($value) `->` type($memref) `,` type($indices)
   }];
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -12,14 +12,19 @@
 
 #include "mlir/Dialect/AMDGPU/AMDGPUDialect.h"
 
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#include <limits>
+
 using namespace mlir;
 using namespace mlir::amdgpu;
 
@@ -62,6 +67,97 @@
   return verifyRawBufferOp(*this);
 }
 
+static Optional<uint32_t> getConstantUint32(Value v) {
+  APInt cst;
+  if (!v.getType().isInteger(32))
+    return None;
+  if (matchPattern(v, m_ConstantInt(&cst)))
+    return cst.getZExtValue();
+  return None;
+}
+
+// Returns success() if the access described by `op` is provably past the end
+// of a statically shaped buffer with bounds checking enabled, and failure()
+// otherwise (including when any index or offset is not a known constant).
+template <typename OpType>
+static LogicalResult staticallyOutOfBounds(OpType op) {
+  if (!op.getBoundsCheck())
+    return failure();
+  MemRefType bufferType = op.getMemref().getType();
+  if (!bufferType.hasStaticShape())
+    return failure();
+  int64_t offset;
+  SmallVector<int64_t> strides;
+  if (failed(getStridesAndOffset(bufferType, strides, offset)))
+    return failure();
+  int64_t result = offset + op.getIndexOffset().value_or(0);
+  if (op.getSgprOffset()) {
+    Optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
+    if (!sgprOffset)
+      return failure();
+    result += *sgprOffset;
+  }
+  if (strides.size() != op.getIndices().size())
+    return failure();
+  int64_t indexVal = 0;
+  for (auto pair : llvm::zip(strides, op.getIndices())) {
+    int64_t stride = std::get<0>(pair);
+    Value idx = std::get<1>(pair);
+    Optional<uint32_t> idxVal = getConstantUint32(idx);
+    if (!idxVal)
+      return failure();
+    indexVal += stride * idxVal.value();
+  }
+  result += indexVal;
+  if (result > std::numeric_limits<uint32_t>::max())
+    // Overflow means don't drop
+    return failure();
+  return success(result >= bufferType.getNumElements());
+}
+
+namespace {
+struct RemoveStaticallyOobBufferLoads final
+    : public OpRewritePattern<RawBufferLoadOp> {
+  using OpRewritePattern<RawBufferLoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(RawBufferLoadOp op,
+                                PatternRewriter &rw) const override {
+    if (succeeded(staticallyOutOfBounds(op))) {
+      Type loadType = op.getResult().getType();
+      rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
+                                               rw.getZeroAttr(loadType));
+      return success();
+    }
+    return failure();
+  }
+};
+
+template <typename OpType>
+struct RemoveStaticallyOobBufferWrites final
+    : public OpRewritePattern<OpType> {
+  using OpRewritePattern<OpType>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpType op,
+                                PatternRewriter &rw) const override {
+    if (succeeded(staticallyOutOfBounds(op))) {
+      rw.eraseOp(op);
+      return success();
+    }
+    return failure();
+  }
+};
+} // end namespace
+
+void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                  MLIRContext *context) {
+  results.add<RemoveStaticallyOobBufferLoads>(context);
+}
+
+void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                   MLIRContext *context) {
+  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
+}
+
+void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // MFMAOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
@@ -10,6 +10,7 @@
   MLIRAMDGPUIncGen
 
   LINK_LIBS PUBLIC
+  MLIRArithmeticDialect
   MLIRIR
   MLIRSideEffectInterfaces
 )
diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -0,0 +1,132 @@
+// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
+
+// CHECK-LABEL: func @known_oob_load
+func.func @known_oob_load(%arg0: memref<4xf32>) -> f32 {
+  // CHECK: %[[zero:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: return %[[zero]]
+  %c4_i32 = arith.constant 4 : i32
+  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %arg0[%c4_i32] : memref<4xf32>, i32 -> f32
+  func.return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @known_oob_load_2d
+func.func @known_oob_load_2d(%arg0: memref<4x4xf32>) -> f32 {
+  // CHECK: %[[zero:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: return %[[zero]]
+  %c0_i32 = arith.constant 0 : i32
+  %c4_i32 = arith.constant 4 : i32
+  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %arg0[%c4_i32, %c0_i32] : memref<4x4xf32>, i32, i32 -> f32
+  func.return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @known_oob_load_2d_on_last
+func.func @known_oob_load_2d_on_last(%arg0: memref<4x4xf32>) -> f32 {
+  // CHECK: %[[zero:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: return %[[zero]]
+  %c0_i32 = arith.constant 0 : i32
+  %c16_i32 = arith.constant 16 : i32
+  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %arg0[%c0_i32, %c16_i32] : memref<4x4xf32>, i32, i32 -> f32
+  func.return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @known_oob_load_index
+func.func @known_oob_load_index(%arg0: memref<4xf32>) -> f32 {
+  // CHECK: %[[zero:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: return %[[zero]]
+  %c0_i32 = arith.constant 0 : i32
+  %0 = amdgpu.raw_buffer_load {boundsCheck = true, indexOffset = 4 : i32} %arg0[%c0_i32] : memref<4xf32>, i32 -> f32
+  func.return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @known_oob_load_sgproffset
+func.func @known_oob_load_sgproffset(%arg0: memref<4xf32>) -> f32 {
+  // CHECK: %[[zero:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: return %[[zero]]
+  %c2_i32 = arith.constant 2 : i32
+  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %arg0[%c2_i32] sgprOffset %c2_i32 : memref<4xf32>, i32 -> f32
+  func.return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @unknown_load
+func.func @unknown_load(%arg0: memref<4xf32>, %arg1: i32) -> f32 {
+  // CHECK: %[[loaded:.*]] = amdgpu.raw_buffer_load
+  // CHECK: return %[[loaded]]
+  %c4_i32 = arith.constant 4 : i32
+  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %arg0[%arg1] sgprOffset %c4_i32 : memref<4xf32>, i32 -> f32
+  func.return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @unknown_load_sgproffset
+func.func @unknown_load_sgproffset(%arg0: memref<4xf32>, %arg1: i32) -> f32 {
+  // CHECK: %[[loaded:.*]] = amdgpu.raw_buffer_load
+  // CHECK: return %[[loaded]]
+  %c4_i32 = arith.constant 4 : i32
+  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %arg0[%c4_i32] sgprOffset %arg1 : memref<4xf32>, i32 -> f32
+  func.return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @unranked
+func.func @unranked(%arg0: memref<?xf32>) -> f32 {
+  // CHECK: %[[loaded:.*]] = amdgpu.raw_buffer_load
+  // CHECK: return %[[loaded]]
+  %c4_i32 = arith.constant 4 : i32
+  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %arg0[%c4_i32] : memref<?xf32>, i32 -> f32
+  func.return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @no_oob_check
+func.func @no_oob_check(%arg0: memref<4xf32>) -> f32 {
+  // CHECK: %[[loaded:.*]] = amdgpu.raw_buffer_load
+  // CHECK: return %[[loaded]]
+  %c4_i32 = arith.constant 4 : i32
+  %0 = amdgpu.raw_buffer_load {boundsCheck = false} %arg0[%c4_i32] : memref<4xf32>, i32 -> f32
+  func.return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @in_bounds_overall
+func.func @in_bounds_overall(%arg0: memref<4x4xf32>) -> f32 {
+  // CHECK: %[[loaded:.*]] = amdgpu.raw_buffer_load
+  // CHECK: return %[[loaded]]
+  %c0_i32 = arith.constant 0 : i32
+  %c15_i32 = arith.constant 15 : i32
+  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %arg0[%c0_i32, %c15_i32] : memref<4x4xf32>, i32, i32 -> f32
+  func.return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @dead_store
+func.func @dead_store(%arg0: memref<4xf32>, %arg1: f32) {
+  // CHECK-NOT: amdgpu.raw_buffer_store
+  %c4_i32 = arith.constant 4 : i32
+  amdgpu.raw_buffer_store {boundsCheck = true} %arg1 -> %arg0[%c4_i32] : f32 -> memref<4xf32>, i32
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL: func @dead_atomic_add
+func.func @dead_atomic_add(%arg0: memref<4xf32>, %arg1: f32) {
+  // CHECK-NOT: amdgpu.raw_buffer_atomic_fadd
+  %c4_i32 = arith.constant 4 : i32
+  amdgpu.raw_buffer_atomic_fadd {boundsCheck = true} %arg1 -> %arg0[%c4_i32] : f32 -> memref<4xf32>, i32
+  func.return
+}
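
Not part of the patch: the test file above exercises these patterns through `mlir-opt -canonicalize`, but since the ops now declare `hasCanonicalizer = 1`, the same rewrites can also be collected and run programmatically. The sketch below assumes this patch is applied; `runOobCleanup` is a hypothetical helper name, and the greedy driver call is the standard MLIR idiom rather than anything introduced by this change.

// Minimal sketch (assumes this patch is applied; `runOobCleanup` is a
// hypothetical name, not part of the patch).
#include "mlir/Dialect/AMDGPU/AMDGPUDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult runOobCleanup(ModuleOp module) {
  MLIRContext *ctx = module.getContext();
  RewritePatternSet patterns(ctx);
  // Collect the canonicalization patterns registered via hasCanonicalizer = 1.
  amdgpu::RawBufferLoadOp::getCanonicalizationPatterns(patterns, ctx);
  amdgpu::RawBufferStoreOp::getCanonicalizationPatterns(patterns, ctx);
  amdgpu::RawBufferAtomicFaddOp::getCanonicalizationPatterns(patterns, ctx);
  // Statically out-of-bounds loads become arith.constant 0; out-of-bounds
  // stores and atomic adds are erased.
  return applyPatternsAndFoldGreedily(module.getOperation(),
                                      std::move(patterns));
}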