diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -877,11 +877,14 @@
-def I32OrF32 : TypeConstraint<Or<[I32.predicate, F32.predicate]>,
-                              "i32 or f32">;
+def I32I64F32OrF64 : TypeConstraint<Or<[I32.predicate,
+                                        I64.predicate,
+                                        F32.predicate,
+                                        F64.predicate]>,
+                                    "i32, i64, f32 or f64">;
 
 def GPU_ShuffleOp : GPU_Op<
     "shuffle", [Pure, AllTypesMatch<["value", "shuffleResult"]>]>,
-    Arguments<(ins I32OrF32:$value, I32:$offset, I32:$width,
+    Arguments<(ins I32I64F32OrF64:$value, I32:$offset, I32:$width,
                    GPU_ShuffleModeAttr:$mode)>,
-    Results<(outs I32OrF32:$shuffleResult, I1:$valid)> {
+    Results<(outs I32I64F32OrF64:$shuffleResult, I1:$valid)> {
   let summary = "Shuffles values within a subgroup.";
   let description = [{
     The "shuffle" op moves values to a different invocation within the same
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -53,12 +53,17 @@
 /// mapped to sequential loops.
 std::unique_ptr<OperationPass<func::FuncOp>> createGpuMapParallelLoopsPass();
 
+/// Collect a set of patterns that rewrite 64-bit (i64/f64) shuffle ops within
+/// the GPU dialect into pairs of 32-bit shuffles.
+void populateGpuShufflePatterns(RewritePatternSet &patterns);
+
 /// Collect a set of patterns to rewrite all-reduce ops within the GPU dialect.
 void populateGpuAllReducePatterns(RewritePatternSet &patterns);
 
 /// Collect all patterns to rewrite ops within the GPU dialect.
 inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
   populateGpuAllReducePatterns(patterns);
+  populateGpuShufflePatterns(patterns);
 }
 
 namespace gpu {
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -49,6 +49,7 @@
   Transforms/KernelOutlining.cpp
   Transforms/MemoryPromotion.cpp
   Transforms/ParallelLoopMapper.cpp
+  Transforms/ShuffleRewriter.cpp
   Transforms/SerializeToBlob.cpp
   Transforms/SerializeToCubin.cpp
   Transforms/SerializeToHsaco.cpp
diff --git a/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp
@@ -0,0 +1,106 @@
+//===- ShuffleRewriter.cpp - Implementation of shuffle rewriting ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements in-dialect rewriting of the shuffle op for types other
+// than i32 and f32, in particular i64 and f64: a 64-bit shuffle is split into
+// two i32 shuffles of the low and high halves of its bits, and the shuffled
+// halves are reassembled into a value of the original type.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Splits a 64-bit (i64 or f64) `gpu.shuffle` into two i32 shuffles:
+///
+///   bits   = value (i64) or bitcast(value) (f64)
+///   lo, hi = trunc(bits), trunc(bits >> 32)
+///   bits'  = (zext(shuffle(hi)) << 32) | zext(shuffle(lo))
+///   valid' = valid(shuffle(lo)) & valid(shuffle(hi))
+///
+/// and bitcasts `bits'` back to f64 when needed. Other types are left alone.
+struct GpuShuffleRewriter : public OpRewritePattern<gpu::ShuffleOp> {
+  using OpRewritePattern<gpu::ShuffleOp>::OpRewritePattern;
+
+  void initialize() {
+    // This pattern creates new gpu.shuffle ops (its own root op kind), so
+    // declare that its recursion is bounded: the new shuffles are 32-bit and
+    // are rejected by the width check in matchAndRewrite.
+    setHasBoundedRewriteRecursion();
+  }
+
+  LogicalResult matchAndRewrite(gpu::ShuffleOp op,
+                                PatternRewriter &rewriter) const override {
+    // Every new op replaces `op`, so it carries the location of `op` rather
+    // than that of the shuffled operand (which may be a block argument).
+    Location loc = op.getLoc();
+    Value value = op.getValue();
+    Type valueType = value.getType();
+    Type i32 = rewriter.getI32Type();
+    Type i64 = rewriter.getI64Type();
+
+    // Only 64-bit scalars need splitting; i32/f32 shuffles are already valid.
+    // Require exactly 64 bits (not merely "not 32"): the shift/truncate
+    // arithmetic below is only correct for 64-bit values, and isIntOrFloat()
+    // guards getIntOrFloatBitWidth(), which asserts on other types.
+    if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 64)
+      return rewriter.notifyMatchFailure(
+          op, "only 64-bit integer or float shuffles are rewritten");
+
+    // Float types must be converted to i64 to extract the bits.
+    bool isFloat = isa<FloatType>(valueType);
+    if (isFloat)
+      value = rewriter.create<arith::BitcastOp>(loc, i64, value);
+
+    // Get the low bits by trunc(value).
+    Value lo = rewriter.create<arith::TruncIOp>(loc, i32, value);
+
+    // Get the high bits by trunc(value >> 32).
+    Value c32 = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIntegerAttr(i64, 32));
+    Value hi = rewriter.create<arith::ShRUIOp>(loc, value, c32);
+    hi = rewriter.create<arith::TruncIOp>(loc, i32, hi);
+
+    // Shuffle both halves with the original offset, width and mode.
+    auto loShuffle = rewriter.create<gpu::ShuffleOp>(
+        loc, lo, op.getOffset(), op.getWidth(), op.getMode());
+    auto hiShuffle = rewriter.create<gpu::ShuffleOp>(
+        loc, hi, op.getOffset(), op.getWidth(), op.getMode());
+
+    // Zero-extend both shuffled halves to i64 and recombine as hi << 32 | lo.
+    Value loBits =
+        rewriter.create<arith::ExtUIOp>(loc, i64, loShuffle.getShuffleResult());
+    Value hiBits =
+        rewriter.create<arith::ExtUIOp>(loc, i64, hiShuffle.getShuffleResult());
+    hiBits = rewriter.create<arith::ShLIOp>(loc, hiBits, c32);
+    Value result = rewriter.create<arith::OrIOp>(loc, hiBits, loBits);
+
+    // Convert the value back to the original float type.
+    if (isFloat)
+      result = rewriter.create<arith::BitcastOp>(loc, valueType, result);
+
+    // The result is valid only if both halves were shuffled validly.
+    Value valid = rewriter.create<arith::AndIOp>(loc, loShuffle.getValid(),
+                                                 hiShuffle.getValid());
+
+    rewriter.replaceOp(op, {result, valid});
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateGpuShufflePatterns(RewritePatternSet &patterns) {
+  patterns.add<GpuShuffleRewriter>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -318,7 +318,7 @@
 // -----
 
 func.func @shuffle_unsupported_type(%arg0 : index, %arg1 : i32, %arg2 : i32) {
-  // expected-error@+1 {{operand #0 must be i32 or f32}}
+  // expected-error@+1 {{operand #0 must be i32, i64, f32 or f64}}
   %shfl, %pred = gpu.shuffle xor %arg0, %arg1, %arg2 : index
   return
 }
diff --git a/mlir/test/Dialect/GPU/shuffle-rewrite.mlir b/mlir/test/Dialect/GPU/shuffle-rewrite.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/GPU/shuffle-rewrite.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt --test-gpu-rewrite -split-input-file %s | FileCheck %s
+
+module {
+  // CHECK-LABEL: func.func @shuffleF64
+  // CHECK-SAME: (%[[SZ:.*]]: index, %[[VALUE:.*]]: f64, %[[OFF:.*]]: i32, %[[WIDTH:.*]]: i32, %[[MEM:.*]]: memref<f64>) {
+  func.func @shuffleF64(%sz : index, %value: f64, %offset: i32, %width: i32, %mem: memref<f64>) {
+    gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
+               threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
+      // CHECK: %[[INTVAL:.*]] = arith.bitcast %[[VALUE]] : f64 to i64
+      // CHECK-NEXT: %[[LO:.*]] = arith.trunci %[[INTVAL]] : i64 to i32
+      // CHECK-NEXT: %[[HI64:.*]] = arith.shrui %[[INTVAL]], %[[C32:.*]] : i64
+      // CHECK-NEXT: %[[HI:.*]] = arith.trunci %[[HI64]] : i64 to i32
+      // CHECK-NEXT: %[[SH1:.*]], %[[V1:.*]] = gpu.shuffle xor %[[LO]], %[[OFF]], %[[WIDTH]] : i32
+      // CHECK-NEXT: %[[SH2:.*]], %[[V2:.*]] = gpu.shuffle xor %[[HI]], %[[OFF]], %[[WIDTH]] : i32
+      // CHECK-NEXT: %[[LOSH:.*]] = arith.extui %[[SH1]] : i32 to i64
+      // CHECK-NEXT: %[[HISHTMP:.*]] = arith.extui %[[SH2]] : i32 to i64
+      // CHECK-NEXT: %[[HISH:.*]] = arith.shli %[[HISHTMP]], %[[C32]] : i64
+      // CHECK-NEXT: %[[SHFLINT:.*]] = arith.ori %[[HISH]], %[[LOSH]] : i64
+      // CHECK-NEXT: %[[SHFL:.*]] = arith.bitcast %[[SHFLINT]] : i64 to f64
+      // CHECK-NEXT: memref.store %[[SHFL]], %[[MEM]][] : memref<f64>
+      %shfl, %pred = gpu.shuffle xor %value, %offset, %width : f64
+      memref.store %shfl, %mem[] : memref<f64>
+      gpu.terminator
+    }
+    return
+  }
+}
+
+// -----
+
+module {
+  // CHECK-LABEL: func.func @shuffleI64
+  // CHECK-SAME: (%[[SZ:.*]]: index, %[[VALUE:.*]]: i64, %[[OFF:.*]]: i32, %[[WIDTH:.*]]: i32, %[[MEM:.*]]: memref<i64>) {
+  func.func @shuffleI64(%sz : index, %value: i64, %offset: i32, %width: i32, %mem: memref<i64>) {
+    gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
+               threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
+      // CHECK: %[[LO:.*]] = arith.trunci %[[VALUE]] : i64 to i32
+      // CHECK-NEXT: %[[HI64:.*]] = arith.shrui %[[VALUE]], %[[C32:.*]] : i64
+      // CHECK-NEXT: %[[HI:.*]] = arith.trunci %[[HI64]] : i64 to i32
+      // CHECK-NEXT: %[[SH1:.*]], %[[V1:.*]] = gpu.shuffle xor %[[LO]], %[[OFF]], %[[WIDTH]] : i32
+      // CHECK-NEXT: %[[SH2:.*]], %[[V2:.*]] = gpu.shuffle xor %[[HI]], %[[OFF]], %[[WIDTH]] : i32
+      // CHECK-NEXT: %[[LOSH:.*]] = arith.extui %[[SH1]] : i32 to i64
+      // CHECK-NEXT: %[[HISHTMP:.*]] = arith.extui %[[SH2]] : i32 to i64
+      // CHECK-NEXT: %[[HISH:.*]] = arith.shli %[[HISHTMP]], %[[C32]] : i64
+      // CHECK-NEXT: %[[SHFLINT:.*]] = arith.ori %[[HISH]], %[[LOSH]] : i64
+      // CHECK-NEXT: memref.store %[[SHFLINT]], %[[MEM]][] : memref<i64>
+      %shfl, %pred = gpu.shuffle xor %value, %offset, %width : i64
+      memref.store %shfl, %mem[] : memref<i64>
+      gpu.terminator
+    }
+    return
+  }
+}
+
+// -----
+
+module {
+  // 32-bit shuffles are already valid and must be left untouched.
+  // CHECK-LABEL: func.func @shuffleI32
+  // CHECK-SAME: (%[[SZ:.*]]: index, %[[VALUE:.*]]: i32, %[[OFF:.*]]: i32, %[[WIDTH:.*]]: i32, %[[MEM:.*]]: memref<i32>) {
+  func.func @shuffleI32(%sz : index, %value: i32, %offset: i32, %width: i32, %mem: memref<i32>) {
+    gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
+               threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
+      // CHECK: %[[SH:.*]], %[[V:.*]] = gpu.shuffle xor %[[VALUE]], %[[OFF]], %[[WIDTH]] : i32
+      // CHECK-NEXT: memref.store %[[SH]], %[[MEM]][] : memref<i32>
+      %shfl, %pred = gpu.shuffle xor %value, %offset, %width : i32
+      memref.store %shfl, %mem[] : memref<i32>
+      gpu.terminator
+    }
+    return
+  }
+}