From db8fde409dac58880421407298b19a914f3ffa65 Mon Sep 17 00:00:00 2001 From: Or Biri Date: Sat, 23 Nov 2024 15:53:33 +0200 Subject: [PATCH 1/2] [CIR][ThroughMLIR] Lower `cir.bool` to i1 --- .../ThroughMLIR/LowerCIRLoopToSCF.cpp | 5 +- .../Lowering/ThroughMLIR/LowerCIRToMLIR.cpp | 123 ++++++++++++------ clang/test/CIR/CodeGen/globals.cpp | 4 + clang/test/CIR/Lowering/ThroughMLIR/bool.cir | 5 +- .../test/CIR/Lowering/ThroughMLIR/branch.cir | 14 +- clang/test/CIR/Lowering/ThroughMLIR/cast.cir | 32 ++--- clang/test/CIR/Lowering/ThroughMLIR/cmp.cpp | 15 +-- clang/test/CIR/Lowering/ThroughMLIR/doWhile.c | 12 +- clang/test/CIR/Lowering/ThroughMLIR/if.c | 16 +-- .../test/CIR/Lowering/ThroughMLIR/tenary.cir | 6 +- clang/test/CIR/Lowering/ThroughMLIR/while.c | 12 +- 11 files changed, 128 insertions(+), 116 deletions(-) diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp index 16252e1058cd..d3cccda6cdd7 100644 --- a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp +++ b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp @@ -337,11 +337,8 @@ class CIRConditionOpLowering auto *parentOp = op->getParentOp(); return llvm::TypeSwitch(parentOp) .Case([&](auto) { - auto condition = adaptor.getCondition(); - auto i1Condition = rewriter.create( - op.getLoc(), rewriter.getI1Type(), condition); rewriter.replaceOpWithNewOp( - op, i1Condition, parentOp->getOperands()); + op, adaptor.getCondition(), parentOp->getOperands()); return mlir::success(); }) .Default([](auto) { return mlir::failure(); }); diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp index cf2980d49410..30b282b9e3ef 100644 --- a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp +++ b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp @@ -35,6 +35,7 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/Region.h" #include "mlir/IR/TypeRange.h" +#include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" @@ -105,6 +106,54 @@ class CIRCallOpLowering : public mlir::OpConversionPattern { } }; +/// Given a type convertor and a data layout, convert the given type to a type +/// that is suitable for memory operations. For example, this can be used to +/// lower cir.bool accesses to i8. +static mlir::Type convertTypeForMemory(const mlir::TypeConverter &converter, + mlir::Type type) { + // TODO(cir): Handle other types similarly to clang's codegen + // convertTypeForMemory + if (isa(type)) { + // TODO: Use datalayout to get the size of bool + return mlir::IntegerType::get(type.getContext(), 8); + } + + return converter.convertType(type); +} + +/// Emits the value from memory as expected by its users. Should be called when +/// the memory represetnation of a CIR type is not equal to its scalar +/// representation. +static mlir::Value emitFromMemory(mlir::ConversionPatternRewriter &rewriter, + cir::LoadOp op, mlir::Value value) { + + // TODO(cir): Handle other types similarly to clang's codegen EmitFromMemory + if (isa(op.getResult().getType())) { + // Create trunc of value from i8 to i1 + // TODO: Use datalayout to get the size of bool + assert(value.getType().isInteger(8)); + return createIntCast(rewriter, value, rewriter.getI1Type()); + } + + return value; +} + +/// Emits a value to memory with the expected scalar type. Should be called when +/// the memory represetnation of a CIR type is not equal to its scalar +/// representation. +static mlir::Value emitToMemory(mlir::ConversionPatternRewriter &rewriter, + cir::StoreOp op, mlir::Value value) { + + // TODO(cir): Handle other types similarly to clang's codegen EmitToMemory + if (isa(op.getValue().getType())) { + // Create zext of value from i1 to i8 + // TODO: Use datalayout to get the size of bool + return createIntCast(rewriter, value, rewriter.getI8Type()); + } + + return value; +} + class CIRAllocaOpLowering : public mlir::OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -112,8 +161,9 @@ class CIRAllocaOpLowering : public mlir::OpConversionPattern { mlir::LogicalResult matchAndRewrite(cir::AllocaOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { - auto type = adaptor.getAllocaType(); - auto mlirType = getTypeConverter()->convertType(type); + + mlir::Type mlirType = + convertTypeForMemory(*getTypeConverter(), adaptor.getAllocaType()); // FIXME: Some types can not be converted yet (e.g. struct) if (!mlirType) @@ -174,12 +224,20 @@ class CIRLoadOpLowering : public mlir::OpConversionPattern { mlir::Value base; SmallVector indices; SmallVector eraseList; + mlir::memref::LoadOp newLoad; if (findBaseAndIndices(adaptor.getAddr(), base, indices, eraseList, rewriter)) { - rewriter.replaceOpWithNewOp(op, base, indices); + newLoad = + rewriter.create(op.getLoc(), base, indices); + // rewriter.replaceOpWithNewOp(op, base, indices); eraseIfSafe(op.getAddr(), base, eraseList, rewriter); } else - rewriter.replaceOpWithNewOp(op, adaptor.getAddr()); + newLoad = + rewriter.create(op.getLoc(), adaptor.getAddr()); + + // Convert adapted result to its original type if needed. + mlir::Value result = emitFromMemory(rewriter, op, newLoad.getResult()); + rewriter.replaceOp(op, result); return mlir::LogicalResult::success(); } }; @@ -194,13 +252,16 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern { mlir::Value base; SmallVector indices; SmallVector eraseList; + + // Convert adapted value to its memory type if needed. + mlir::Value value = emitToMemory(rewriter, op, adaptor.getValue()); if (findBaseAndIndices(adaptor.getAddr(), base, indices, eraseList, rewriter)) { - rewriter.replaceOpWithNewOp(op, adaptor.getValue(), - base, indices); + rewriter.replaceOpWithNewOp(op, value, base, + indices); eraseIfSafe(op.getAddr(), base, eraseList, rewriter); } else - rewriter.replaceOpWithNewOp(op, adaptor.getValue(), + rewriter.replaceOpWithNewOp(op, value, adaptor.getAddr()); return mlir::LogicalResult::success(); } @@ -744,29 +805,20 @@ class CIRCmpOpLowering : public mlir::OpConversionPattern { mlir::ConversionPatternRewriter &rewriter) const override { auto type = op.getLhs().getType(); - mlir::Value mlirResult; - if (auto ty = mlir::dyn_cast(type)) { auto kind = convertCmpKindToCmpIPredicate(op.getKind(), ty.isSigned()); - mlirResult = rewriter.create( - op.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs()); + rewriter.replaceOpWithNewOp( + op, kind, adaptor.getLhs(), adaptor.getRhs()); } else if (auto ty = mlir::dyn_cast(type)) { auto kind = convertCmpKindToCmpFPredicate(op.getKind()); - mlirResult = rewriter.create( - op.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs()); + rewriter.replaceOpWithNewOp( + op, kind, adaptor.getLhs(), adaptor.getRhs()); } else if (auto ty = mlir::dyn_cast(type)) { llvm_unreachable("pointer comparison not supported yet"); } else { return op.emitError() << "unsupported type for CmpOp: " << type; } - // MLIR comparison ops return i1, but cir::CmpOp returns the same type as - // the LHS value. Since this return value can be used later, we need to - // restore the type with the extension below. - auto mlirResultTy = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, mlirResultTy, - mlirResult); - return mlir::LogicalResult::success(); } }; @@ -826,12 +878,8 @@ struct CIRBrCondOpLowering : public mlir::OpConversionPattern { mlir::LogicalResult matchAndRewrite(cir::BrCondOp brOp, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { - - auto condition = adaptor.getCond(); - auto i1Condition = rewriter.create( - brOp.getLoc(), rewriter.getI1Type(), condition); rewriter.replaceOpWithNewOp( - brOp, i1Condition.getResult(), brOp.getDestTrue(), + brOp, adaptor.getCond(), brOp.getDestTrue(), adaptor.getDestOperandsTrue(), brOp.getDestFalse(), adaptor.getDestOperandsFalse()); @@ -847,16 +895,13 @@ class CIRTernaryOpLowering : public mlir::OpConversionPattern { matchAndRewrite(cir::TernaryOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { rewriter.setInsertionPoint(op); - auto condition = adaptor.getCond(); - auto i1Condition = rewriter.create( - op.getLoc(), rewriter.getI1Type(), condition); SmallVector resultTypes; if (mlir::failed(getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes))) return mlir::failure(); auto ifOp = rewriter.create(op.getLoc(), resultTypes, - i1Condition.getResult(), true); + adaptor.getCond(), true); auto *thenBlock = &ifOp.getThenRegion().front(); auto *elseBlock = &ifOp.getElseRegion().front(); rewriter.inlineBlockBefore(&op.getTrueRegion().front(), thenBlock, @@ -893,11 +938,8 @@ class CIRIfOpLowering : public mlir::OpConversionPattern { mlir::LogicalResult matchAndRewrite(cir::IfOp ifop, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { - auto condition = adaptor.getCondition(); - auto i1Condition = rewriter.create( - ifop->getLoc(), rewriter.getI1Type(), condition); auto newIfOp = rewriter.create( - ifop->getLoc(), ifop->getResultTypes(), i1Condition); + ifop->getLoc(), ifop->getResultTypes(), adaptor.getCondition()); auto *thenBlock = rewriter.createBlock(&newIfOp.getThenRegion()); rewriter.inlineBlockBefore(&ifop.getThenRegion().front(), thenBlock, thenBlock->end()); @@ -924,7 +966,7 @@ class CIRGlobalOpLowering : public mlir::OpConversionPattern { mlir::OpBuilder b(moduleOp.getContext()); const auto CIRSymType = op.getSymType(); - auto convertedType = getTypeConverter()->convertType(CIRSymType); + auto convertedType = convertTypeForMemory(*getTypeConverter(), CIRSymType); if (!convertedType) return mlir::failure(); auto memrefType = dyn_cast(convertedType); @@ -1170,19 +1212,14 @@ class CIRCastOpLowering : public mlir::OpConversionPattern { return mlir::success(); } case CIR::float_to_bool: { - auto dstTy = mlir::cast(op.getType()); - auto newDstType = convertTy(dstTy); auto kind = mlir::arith::CmpFPredicate::UNE; // Check if float is not equal to zero. auto zeroFloat = rewriter.create( op.getLoc(), src.getType(), mlir::FloatAttr::get(src.getType(), 0.0)); - // Extend comparison result to either bool (C++) or int (C). - mlir::Value cmpResult = rewriter.create( - op.getLoc(), kind, src, zeroFloat); - rewriter.replaceOpWithNewOp(op, newDstType, - cmpResult); + rewriter.replaceOpWithNewOp(op, kind, src, + zeroFloat); return mlir::success(); } case CIR::bool_to_int: { @@ -1330,7 +1367,7 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns, static mlir::TypeConverter prepareTypeConverter() { mlir::TypeConverter converter; converter.addConversion([&](cir::PointerType type) -> mlir::Type { - auto ty = converter.convertType(type.getPointee()); + auto ty = convertTypeForMemory(converter, type.getPointee()); // FIXME: The pointee type might not be converted (e.g. struct) if (!ty) return nullptr; @@ -1350,7 +1387,7 @@ static mlir::TypeConverter prepareTypeConverter() { mlir::IntegerType::SignednessSemantics::Signless); }); converter.addConversion([&](cir::BoolType type) -> mlir::Type { - return mlir::IntegerType::get(type.getContext(), 8); + return mlir::IntegerType::get(type.getContext(), 1); }); converter.addConversion([&](cir::SingleType type) -> mlir::Type { return mlir::FloatType::getF32(type.getContext()); diff --git a/clang/test/CIR/CodeGen/globals.cpp b/clang/test/CIR/CodeGen/globals.cpp index ca8161b1cb8f..3b91bacfed22 100644 --- a/clang/test/CIR/CodeGen/globals.cpp +++ b/clang/test/CIR/CodeGen/globals.cpp @@ -20,6 +20,10 @@ void use_global() { int li = a; } +bool bool_global() { + return e; +} + void use_global_string() { unsigned char c = s2[0]; } diff --git a/clang/test/CIR/Lowering/ThroughMLIR/bool.cir b/clang/test/CIR/Lowering/ThroughMLIR/bool.cir index 408cac97ee41..5383477255aa 100644 --- a/clang/test/CIR/Lowering/ThroughMLIR/bool.cir +++ b/clang/test/CIR/Lowering/ThroughMLIR/bool.cir @@ -14,8 +14,9 @@ module { // MLIR: func @foo() { // MLIR: [[Value:%[a-z0-9]+]] = memref.alloca() {alignment = 1 : i64} : memref -// MLIR: = arith.constant 1 : i8 -// MLIR: memref.store {{.*}}, [[Value]][] : memref +// MLIR: %[[CONST:.*]] = arith.constant true +// MLIR: %[[BOOL_TO_MEM:.*]] = arith.extui %[[CONST]] : i1 to i8 +// MLIR-NEXT: memref.store %[[BOOL_TO_MEM]], [[Value]][] : memref // return // LLVM: = alloca i8, i64 diff --git a/clang/test/CIR/Lowering/ThroughMLIR/branch.cir b/clang/test/CIR/Lowering/ThroughMLIR/branch.cir index 2b78484627d5..89cd8849a3ca 100644 --- a/clang/test/CIR/Lowering/ThroughMLIR/branch.cir +++ b/clang/test/CIR/Lowering/ThroughMLIR/branch.cir @@ -13,9 +13,8 @@ cir.func @foo(%arg0: !cir.bool) -> !s32i { } // MLIR: module { -// MLIR-NEXT: func.func @foo(%arg0: i8) -> i32 -// MLIR-NEXT: %0 = arith.trunci %arg0 : i8 to i1 -// MLIR-NEXT: cf.cond_br %0, ^bb1, ^bb2 +// MLIR-NEXT: func.func @foo(%arg0: i1) -> i32 +// MLIR-NEXT: cf.cond_br %arg0, ^bb1, ^bb2 // MLIR-NEXT: ^bb1: // pred: ^bb0 // MLIR-NEXT: %c1_i32 = arith.constant 1 : i32 // MLIR-NEXT: return %c1_i32 : i32 @@ -25,13 +24,12 @@ cir.func @foo(%arg0: !cir.bool) -> !s32i { // MLIR-NEXT: } // MLIR-NEXT: } -// LLVM: define i32 @foo(i8 %0) -// LLVM-NEXT: %2 = trunc i8 %0 to i1 -// LLVM-NEXT: br i1 %2, label %3, label %4 +// LLVM: define i32 @foo(i1 %0) +// LLVM-NEXT: br i1 %0, label %[[TRUE:.*]], label %[[FALSE:.*]] // LLVM-EMPTY: -// LLVM-NEXT: 3: ; preds = %1 +// LLVM-NEXT: [[TRUE]]: // LLVM-NEXT: ret i32 1 // LLVM-EMPTY: -// LLVM-NEXT: 4: ; preds = %1 +// LLVM-NEXT: [[FALSE]]: // LLVM-NEXT: ret i32 0 // LLVM-NEXT: } diff --git a/clang/test/CIR/Lowering/ThroughMLIR/cast.cir b/clang/test/CIR/Lowering/ThroughMLIR/cast.cir index 18452a456880..8812e77dd583 100644 --- a/clang/test/CIR/Lowering/ThroughMLIR/cast.cir +++ b/clang/test/CIR/Lowering/ThroughMLIR/cast.cir @@ -7,8 +7,8 @@ !u16i = !cir.int !u8i = !cir.int module { - // MLIR-LABEL: func.func @cast_int_to_bool(%arg0: i32) -> i8 - // LLVM-LABEL: define i8 @cast_int_to_bool(i32 %0) + // MLIR-LABEL: func.func @cast_int_to_bool(%arg0: i32) -> i1 + // LLVM-LABEL: define i1 @cast_int_to_bool(i32 %0) cir.func @cast_int_to_bool(%i : !u32i) -> !cir.bool { // MLIR-NEXT: %[[ZERO:.*]] = arith.constant 0 : i32 // MLIR-NEXT: arith.cmpi ne, %arg0, %[[ZERO]] @@ -71,8 +71,8 @@ module { %1 = cir.cast(floating, %f : !cir.float), !cir.double cir.return %1 : !cir.double } - // MLIR-LABEL: func.func @cast_float_to_bool(%arg0: f32) -> i8 - // LLVM-LABEL: define i8 @cast_float_to_bool(float %0) + // MLIR-LABEL: func.func @cast_float_to_bool(%arg0: f32) -> i1 + // LLVM-LABEL: define i1 @cast_float_to_bool(float %0) cir.func @cast_float_to_bool(%f : !cir.float) -> !cir.bool { // MLIR-NEXT: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 // MLIR-NEXT: arith.cmpf une, %arg0, %[[ZERO]] : f32 @@ -81,29 +81,29 @@ module { %1 = cir.cast(float_to_bool, %f : !cir.float), !cir.bool cir.return %1 : !cir.bool } - // MLIR-LABEL: func.func @cast_bool_to_int8(%arg0: i8) -> i8 - // LLVM-LABEL: define i8 @cast_bool_to_int8(i8 %0) + // MLIR-LABEL: func.func @cast_bool_to_int8(%arg0: i1) -> i8 + // LLVM-LABEL: define i8 @cast_bool_to_int8(i1 %0) cir.func @cast_bool_to_int8(%b : !cir.bool) -> !u8i { - // MLIR-NEXT: arith.bitcast %arg0 : i8 to i8 - // LLVM-NEXT: ret i8 %0 + // MLIR-NEXT: arith.extui %arg0 : i1 to i8 + // LLVM-NEXT: zext i1 %0 to i8 %1 = cir.cast(bool_to_int, %b : !cir.bool), !u8i cir.return %1 : !u8i } - // MLIR-LABEL: func.func @cast_bool_to_int(%arg0: i8) -> i32 - // LLVM-LABEL: define i32 @cast_bool_to_int(i8 %0) + // MLIR-LABEL: func.func @cast_bool_to_int(%arg0: i1) -> i32 + // LLVM-LABEL: define i32 @cast_bool_to_int(i1 %0) cir.func @cast_bool_to_int(%b : !cir.bool) -> !u32i { - // MLIR-NEXT: arith.extui %arg0 : i8 to i32 - // LLVM-NEXT: zext i8 %0 to i32 + // MLIR-NEXT: arith.extui %arg0 : i1 to i32 + // LLVM-NEXT: zext i1 %0 to i32 %1 = cir.cast(bool_to_int, %b : !cir.bool), !u32i cir.return %1 : !u32i } - // MLIR-LABEL: func.func @cast_bool_to_float(%arg0: i8) -> f32 - // LLVM-LABEL: define float @cast_bool_to_float(i8 %0) + // MLIR-LABEL: func.func @cast_bool_to_float(%arg0: i1) -> f32 + // LLVM-LABEL: define float @cast_bool_to_float(i1 %0) cir.func @cast_bool_to_float(%b : !cir.bool) -> !cir.float { - // MLIR-NEXT: arith.uitofp %arg0 : i8 to f32 - // LLVM-NEXT: uitofp i8 %0 to float + // MLIR-NEXT: arith.uitofp %arg0 : i1 to f32 + // LLVM-NEXT: uitofp i1 %0 to float %1 = cir.cast(bool_to_float, %b : !cir.bool), !cir.float cir.return %1 : !cir.float diff --git a/clang/test/CIR/Lowering/ThroughMLIR/cmp.cpp b/clang/test/CIR/Lowering/ThroughMLIR/cmp.cpp index fcb9247bfb8f..607f8ad5005f 100644 --- a/clang/test/CIR/Lowering/ThroughMLIR/cmp.cpp +++ b/clang/test/CIR/Lowering/ThroughMLIR/cmp.cpp @@ -7,10 +7,10 @@ bool testSignedIntCmpOps(int a, int b) { // CHECK: %[[ALLOC3:.+]] = memref.alloca() {alignment = 1 : i64} : memref // CHECK: %[[ALLOC4:.+]] = memref.alloca() {alignment = 1 : i64} : memref // CHECK: memref.store %arg0, %[[ALLOC1]][] : memref - // CHECK: memref.store %arg1, %[[ALLOC2]][] : memref - + // CHECK: memref.store %arg1, %[[ALLOC2]][] : memref + bool x = a == b; - + // CHECK: %[[LOAD0:.+]] = memref.load %[[ALLOC1]][] : memref // CHECK: %[[LOAD1:.+]] = memref.load %[[ALLOC2]][] : memref // CHECK: %[[CMP0:.+]] = arith.cmpi eq, %[[LOAD0]], %[[LOAD1]] : i32 @@ -57,11 +57,8 @@ bool testSignedIntCmpOps(int a, int b) { // CHECK: %[[EXT5:.+]] = arith.extui %[[CMP5]] : i1 to i8 // CHECK: memref.store %[[EXT5]], %[[ALLOC4]][] : memref - // CHECK: %[[LOAD12:.+]] = memref.load %[[ALLOC4]][] : memref - // CHECK: memref.store %[[LOAD12]], %[[ALLOC3]][] : memref - // CHECK: %[[LOAD13:.+]] = memref.load %[[ALLOC3]][] : memref - // CHECK: return %[[LOAD13]] : i8 return x; + // CHECK: return } bool testUnSignedIntBinOps(unsigned a, unsigned b) { @@ -71,7 +68,7 @@ bool testUnSignedIntBinOps(unsigned a, unsigned b) { // CHECK: %[[ALLOC4:.+]] = memref.alloca() {alignment = 1 : i64} : memref // CHECK: memref.store %arg0, %[[ALLOC1]][] : memref // CHECK: memref.store %arg1, %[[ALLOC2]][] : memref - + bool x = a == b; // CHECK: %[[LOAD0:.+]] = memref.load %[[ALLOC1]][] : memref @@ -182,4 +179,4 @@ bool testFloatingPointCmpOps(float a, float b) { return x; // CHECK: return -} \ No newline at end of file +} diff --git a/clang/test/CIR/Lowering/ThroughMLIR/doWhile.c b/clang/test/CIR/Lowering/ThroughMLIR/doWhile.c index cf1e275caece..5974734740a2 100644 --- a/clang/test/CIR/Lowering/ThroughMLIR/doWhile.c +++ b/clang/test/CIR/Lowering/ThroughMLIR/doWhile.c @@ -45,9 +45,7 @@ void nestedDoWhile() { // CHECK: %[[VAR4:.+]] = memref.load %[[ALLOC1]][] : memref // CHECK: %[[C10_I32:.+]] = arith.constant 10 : i32 // CHECK: %[[CMP:.+]] = arith.cmpi sle, %[[VAR4]], %[[C10_I32]] : i32 -// CHECK: %[[EXT1:.+]] = arith.extui %[[CMP]] : i1 to i8 -// CHECK: %[[TRUNC:.+]] = arith.trunci %[[EXT1]] : i8 to i1 -// CHECK: scf.condition(%[[TRUNC]]) +// CHECK: scf.condition(%[[CMP]]) // CHECK: } do { // CHECK: scf.yield // CHECK: } @@ -76,9 +74,7 @@ void nestedDoWhile() { // CHECK: %[[EIGHT:.+]] = memref.load %[[alloca_0]][] : memref // CHECK: %[[C2_I32_3:.+]] = arith.constant 2 : i32 // CHECK: %[[NINE:.+]] = arith.cmpi slt, %[[EIGHT]], %[[C2_I32_3]] : i32 -// CHECK: %[[TWELVE:.+]] = arith.extui %[[NINE]] : i1 to i8 -// CHECK: %[[THIRTEEN:.+]] = arith.trunci %[[TWELVE]] : i8 to i1 -// CHECK: scf.condition(%[[THIRTEEN]]) +// CHECK: scf.condition(%[[NINE]]) // CHECK: } do { // CHECK: %[[EIGHT]] = memref.load %[[alloca_0]][] : memref // CHECK: %[[C1_I32_3:.+]] = arith.constant 1 : i32 @@ -91,9 +87,7 @@ void nestedDoWhile() { // CHECK: %[[TWO:.+]] = memref.load %[[alloca]][] : memref // CHECK: %[[C2_I32:.+]] = arith.constant 2 : i32 // CHECK: %[[THREE:.+]] = arith.cmpi slt, %[[TWO]], %[[C2_I32]] : i32 -// CHECK: %[[SIX:.+]] = arith.extui %[[THREE]] : i1 to i8 -// CHECK: %[[SEVEN:.+]] = arith.trunci %[[SIX]] : i8 to i1 -// CHECK: scf.condition(%[[SEVEN]]) +// CHECK: scf.condition(%[[THREE]]) // CHECK: } do { // CHECK: scf.yield // CHECK: } diff --git a/clang/test/CIR/Lowering/ThroughMLIR/if.c b/clang/test/CIR/Lowering/ThroughMLIR/if.c index 8e88346c727f..dec3f9968d6a 100644 --- a/clang/test/CIR/Lowering/ThroughMLIR/if.c +++ b/clang/test/CIR/Lowering/ThroughMLIR/if.c @@ -22,9 +22,7 @@ void foo() { //CHECK: %[[ZERO:.+]] = memref.load %[[alloca]][] : memref //CHECK: %[[C0_I32_1:.+]] = arith.constant 0 : i32 //CHECK: %[[ONE:.+]] = arith.cmpi sgt, %[[ZERO]], %[[C0_I32_1]] : i32 -//CHECK: %[[FOUR:.+]] = arith.extui %[[ONE]] : i1 to i8 -//CHECK: %[[FIVE:.+]] = arith.trunci %[[FOUR]] : i8 to i1 -//CHECK: scf.if %[[FIVE]] { +//CHECK: scf.if %[[ONE]] { //CHECK: %[[SIX:.+]] = memref.load %[[alloca_0]][] : memref //CHECK: %[[C1_I32:.+]] = arith.constant 1 : i32 //CHECK: %[[SEVEN:.+]] = arith.addi %[[SIX]], %[[C1_I32]] : i32 @@ -58,9 +56,7 @@ void foo2() { //CHECK: %[[ZERO:.+]] = memref.load %[[alloca]][] : memref //CHECK: %[[C3_I32:.+]] = arith.constant 3 : i32 //CHECK: %[[ONE:.+]] = arith.cmpi slt, %[[ZERO]], %[[C3_I32]] : i32 -//CHECK: %[[FOUR:.+]] = arith.extui %[[ONE]] : i1 to i8 -//CHECK: %[[FIVE]] = arith.trunci %[[FOUR]] : i8 to i1 -//CHECK: scf.if %[[FIVE]] { +//CHECK: scf.if %[[ONE]] { //CHECK: %[[SIX:.+]] = memref.load %[[alloca_0]][] : memref //CHECK: %[[C1_I32:.+]] = arith.constant 1 : i32 //CHECK: %[[SEVEN:.+]] = arith.addi %[[SIX]], %[[C1_I32]] : i32 @@ -95,9 +91,7 @@ void foo3() { //CHECK: %[[ZERO:.+]] = memref.load %[[alloca]][] : memref //CHECK: %[[C3_I32:.+]] = arith.constant 3 : i32 //CHECK: %[[ONE:.+]] = arith.cmpi slt, %[[ZERO]], %[[C3_I32]] : i32 -//CHECK: %[[FOUR:.+]] = arith.extui %[[ONE]] : i1 to i8 -//CHECK: %[[FIVE]] = arith.trunci %[[FOUR]] : i8 to i1 -//CHECK: scf.if %[[FIVE]] { +//CHECK: scf.if %[[ONE]] { //CHECK: %[[alloca_2:.+]] = memref.alloca() {alignment = 4 : i64} : memref //CHECK: %[[C1_I32:.+]] = arith.constant 1 : i32 //CHECK: memref.store %[[C1_I32]], %[[alloca_2]][] : memref @@ -105,9 +99,7 @@ void foo3() { //CHECK: %[[SIX:.+]] = memref.load %[[alloca_2]][] : memref //CHECK: %[[C2_I32_3:.+]] = arith.constant 2 : i32 //CHECK: %[[SEVEN:.+]] = arith.cmpi sgt, %[[SIX]], %[[C2_I32_3]] : i32 -//CHECK: %[[TEN:.+]] = arith.extui %[[SEVEN]] : i1 to i8 -//CHECK: %[[ELEVEN:.+]] = arith.trunci %[[TEN]] : i8 to i1 -//CHECK: scf.if %[[ELEVEN]] { +//CHECK: scf.if %[[SEVEN]] { //CHECK: %[[TWELVE:.+]] = memref.load %[[alloca_0]][] : memref //CHECK: %[[C1_I32_5:.+]] = arith.constant 1 : i32 //CHECK: %[[THIRTEEN:.+]] = arith.addi %[[TWELVE]], %[[C1_I32_5]] : i32 diff --git a/clang/test/CIR/Lowering/ThroughMLIR/tenary.cir b/clang/test/CIR/Lowering/ThroughMLIR/tenary.cir index ce6f466aebc9..819b4c3b941e 100644 --- a/clang/test/CIR/Lowering/ThroughMLIR/tenary.cir +++ b/clang/test/CIR/Lowering/ThroughMLIR/tenary.cir @@ -26,16 +26,14 @@ cir.func @_Z1xi(%arg0: !s32i) -> !s32i { } // MLIR: %1 = arith.cmpi sgt, %0, %c0_i32 : i32 -// MLIR-NEXT: %2 = arith.extui %1 : i1 to i8 -// MLIR-NEXT: %3 = arith.trunci %2 : i8 to i1 -// MLIR-NEXT: %4 = scf.if %3 -> (i32) { +// MLIR-NEXT: %2 = scf.if %1 -> (i32) { // MLIR-NEXT: %c3_i32 = arith.constant 3 : i32 // MLIR-NEXT: scf.yield %c3_i32 : i32 // MLIR-NEXT: } else { // MLIR-NEXT: %c5_i32 = arith.constant 5 : i32 // MLIR-NEXT: scf.yield %c5_i32 : i32 // MLIR-NEXT: } -// MLIR-NEXT: memref.store %4, %alloca_0[] : memref +// MLIR-NEXT: memref.store %2, %alloca_0[] : memref // MLIR-CANONICALIZE: %[[CMP:.*]] = arith.cmpi sgt // MLIR-CANONICALIZE: arith.select %[[CMP]] diff --git a/clang/test/CIR/Lowering/ThroughMLIR/while.c b/clang/test/CIR/Lowering/ThroughMLIR/while.c index 5621e1fc7c4a..68454f3bea99 100644 --- a/clang/test/CIR/Lowering/ThroughMLIR/while.c +++ b/clang/test/CIR/Lowering/ThroughMLIR/while.c @@ -28,9 +28,7 @@ void nestedWhile() { //CHECK: %[[ZERO:.+]] = memref.load %[[alloca]][] : memref //CHECK: %[[C2_I32:.+]] = arith.constant 2 : i32 //CHECK: %[[ONE:.+]] = arith.cmpi slt, %[[ZERO:.+]], %[[C2_I32]] : i32 -//CHECK: %[[FOUR:.+]] = arith.extui %[[ONE:.+]] : i1 to i8 -//CHECK: %[[FIVE:.+]] = arith.trunci %[[FOUR:.+]] : i8 to i1 -//CHECK: scf.condition(%[[FIVE]]) +//CHECK: scf.condition(%[[ONE]]) //CHECK: } do { //CHECK: memref.alloca_scope { //CHECK: %[[ZERO:.+]] = memref.load %[[alloca]][] : memref @@ -53,9 +51,7 @@ void nestedWhile() { //CHECK: %[[ZERO:.+]] = memref.load %alloca[] : memref //CHECK: %[[C2_I32:.+]] = arith.constant 2 : i32 //CHECK: %[[ONE:.+]] = arith.cmpi slt, %[[ZERO]], %[[C2_I32]] : i32 -//CHECK: %[[FOUR:.+]] = arith.extui %[[ONE]] : i1 to i8 -//CHECK: %[[FIVE:.+]] = arith.trunci %[[FOUR]] : i8 to i1 -//CHECK: scf.condition(%[[FIVE]]) +//CHECK: scf.condition(%[[ONE]]) //CHECK: } do { //CHECK: memref.alloca_scope { //CHECK: %[[alloca_0:.+]] = memref.alloca() {alignment = 4 : i64} : memref @@ -65,9 +61,7 @@ void nestedWhile() { //CHECK: scf.while : () -> () { //CHECK: %{{.*}} = memref.load %[[alloca_0]][] : memref //CHECK: %[[C2_I32]] = arith.constant 2 : i32 -//CHECK: %{{.*}} = arith.cmpi slt, %{{.*}}, %[[C2_I32]] : i32 -//CHECK: %[[SIX:.+]] = arith.extui %{{.*}} : i1 to i8 -//CHECK: %[[SEVEN:.+]] = arith.trunci %[[SIX]] : i8 to i1 +//CHECK: %[[SEVEN:.*]] = arith.cmpi slt, %{{.*}}, %[[C2_I32]] : i32 //CHECK: scf.condition(%[[SEVEN]]) //CHECK: } do { //CHECK: %{{.*}} = memref.load %[[alloca_0]][] : memref From 3450f9ea6cda6694ca05a32e645b2f5ba235d878 Mon Sep 17 00:00:00 2001 From: Or Biri Date: Sun, 22 Dec 2024 00:55:49 +0200 Subject: [PATCH 2/2] [CIR][DirectToLLVM] Lower `cir.bool` to i1 --- clang/lib/CIR/CodeGen/CIRGenModule.cpp | 2 +- .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 297 +++++++++++------- .../CIR/Lowering/DirectToLLVM/LowerToLLVM.h | 54 +++- clang/test/CIR/CodeGen/atomic-xchg-field.c | 8 +- clang/test/CIR/CodeGen/bf16-ops.c | 10 +- clang/test/CIR/CodeGen/builtin-assume.cpp | 4 +- clang/test/CIR/CodeGen/builtin-constant-p.c | 3 +- clang/test/CIR/CodeGen/complex-arithmetic.c | 9 +- clang/test/CIR/CodeGen/complex-cast.c | 8 +- clang/test/CIR/CodeGen/new-null.cpp | 2 +- .../CodeGen/pointer-to-data-member-cast.cpp | 8 +- clang/test/CIR/CodeGen/static.cpp | 4 +- clang/test/CIR/IR/invalid.cir | 6 +- clang/test/CIR/Lowering/binop-overflow.cir | 20 +- clang/test/CIR/Lowering/bool.cir | 9 +- clang/test/CIR/Lowering/branch.cir | 22 +- clang/test/CIR/Lowering/brcond.cir | 25 +- clang/test/CIR/Lowering/cast.cir | 16 +- clang/test/CIR/Lowering/const-array.cir | 5 + clang/test/CIR/Lowering/const.cir | 7 +- clang/test/CIR/Lowering/loadstorealloca.cir | 20 +- clang/test/CIR/Lowering/ptrstride.cir | 10 + clang/test/CIR/Lowering/select.cir | 24 +- clang/test/CIR/Lowering/struct.cir | 19 ++ clang/test/CIR/Lowering/unary-not.cir | 24 +- clang/test/CIR/Lowering/unions.cir | 3 +- 26 files changed, 388 insertions(+), 231 deletions(-) diff --git a/clang/lib/CIR/CodeGen/CIRGenModule.cpp b/clang/lib/CIR/CodeGen/CIRGenModule.cpp index b35ef11c7782..085e2c237ee7 100644 --- a/clang/lib/CIR/CodeGen/CIRGenModule.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenModule.cpp @@ -820,7 +820,7 @@ void CIRGenModule::replaceGlobal(cir::GlobalOp Old, cir::GlobalOp New) { mlir::Type ptrTy = builder.getPointerTo(OldTy); mlir::Value cast = builder.createBitcast(GGO->getLoc(), UseOpResultValue, ptrTy); - UseOpResultValue.replaceAllUsesExcept(cast, {cast.getDefiningOp()}); + UseOpResultValue.replaceAllUsesExcept(cast, cast.getDefiningOp()); } } } diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index c933035cd850..236cb602f735 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -349,12 +349,81 @@ unsigned getGlobalOpTargetAddrSpace(mlir::ConversionPatternRewriter &rewriter, .getAddressSpace(); } +/// Given a type convertor and a data layout, convert the given type to a type +/// that is suitable for memory operations. For example, this can be used to +/// lower cir.bool accesses to i8. +static mlir::Type convertTypeForMemory(const mlir::TypeConverter &converter, + mlir::DataLayout const &dataLayout, + mlir::Type type) { + // TODO(cir): Handle other types similarly to clang's codegen + // convertTypeForMemory + if (isa(type)) { + return mlir::IntegerType::get(type.getContext(), + dataLayout.getTypeSizeInBits(type)); + } + + return converter.convertType(type); +} + +/// Emits the value from memory as expected by its users. Should be called when +/// the memory represetnation of a CIR type is not equal to its scalar +/// representation. +static mlir::Value emitFromMemory(mlir::ConversionPatternRewriter &rewriter, + mlir::DataLayout const &dataLayout, + cir::LoadOp op, mlir::Value value) { + + // TODO(cir): Handle other types similarly to clang's codegen EmitFromMemory + if (auto boolTy = mlir::dyn_cast(op.getResult().getType())) { + // Create a cast value from specified size in datalayout to i1 + assert(value.getType().isInteger(dataLayout.getTypeSizeInBits(boolTy))); + return createIntCast(rewriter, value, rewriter.getI1Type()); + } + + return value; +} + +/// Emits a value to memory with the expected scalar type. Should be called when +/// the memory represetnation of a CIR type is not equal to its scalar +/// representation. +static mlir::Value emitToMemory(mlir::ConversionPatternRewriter &rewriter, + mlir::DataLayout const &dataLayout, + mlir::Type origType, mlir::Value value) { + + // TODO(cir): Handle other types similarly to clang's codegen EmitToMemory + if (auto boolTy = mlir::dyn_cast(origType)) { + // Create zext of value from i1 to i8 + auto memType = + rewriter.getIntegerType(dataLayout.getTypeSizeInBits(boolTy)); + return createIntCast(rewriter, value, memType); + } + + return value; +} + } // namespace //===----------------------------------------------------------------------===// // Visitors for Lowering CIR Const Attributes //===----------------------------------------------------------------------===// +/// Emits a value to memory with the expected scalar type. Should be called when +/// the memory represetnation of a CIR attribute's type is not equal to its +/// scalar representation. +static mlir::Value +emitCirAttrToMemory(mlir::Operation *parentOp, mlir::Attribute attr, + mlir::ConversionPatternRewriter &rewriter, + const mlir::TypeConverter *converter, + mlir::DataLayout const &dataLayout) { + + mlir::Value loweredValue = + lowerCirAttrAsValue(parentOp, attr, rewriter, converter, dataLayout); + if (auto boolAttr = mlir::dyn_cast(attr)) { + return emitToMemory(rewriter, dataLayout, boolAttr.getType(), loweredValue); + } + + return loweredValue; +} + /// Switches on the type of attribute and calls the appropriate conversion. /// IntAttr visitor. @@ -439,14 +508,16 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::PoisonAttr poisonAttr, static mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstStructAttr constStruct, mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter) { + const mlir::TypeConverter *converter, + mlir::DataLayout const &dataLayout) { auto llvmTy = converter->convertType(constStruct.getType()); auto loc = parentOp->getLoc(); mlir::Value result = rewriter.create(loc, llvmTy); // Iteratively lower each constant element of the struct. for (auto [idx, elt] : llvm::enumerate(constStruct.getMembers())) { - mlir::Value init = lowerCirAttrAsValue(parentOp, elt, rewriter, converter); + mlir::Value init = + emitCirAttrToMemory(parentOp, elt, rewriter, converter, dataLayout); result = rewriter.create(loc, result, init, idx); } @@ -457,13 +528,15 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstStructAttr constStruct, static mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, cir::VTableAttr vtableArr, mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter) { + const mlir::TypeConverter *converter, + mlir::DataLayout const &dataLayout) { auto llvmTy = converter->convertType(vtableArr.getType()); auto loc = parentOp->getLoc(); mlir::Value result = rewriter.create(loc, llvmTy); for (auto [idx, elt] : llvm::enumerate(vtableArr.getVtableData())) { - mlir::Value init = lowerCirAttrAsValue(parentOp, elt, rewriter, converter); + mlir::Value init = + lowerCirAttrAsValue(parentOp, elt, rewriter, converter, dataLayout); result = rewriter.create(loc, result, init, idx); } @@ -474,13 +547,15 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::VTableAttr vtableArr, static mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, cir::TypeInfoAttr typeinfoArr, mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter) { + const mlir::TypeConverter *converter, + mlir::DataLayout const &dataLayout) { auto llvmTy = converter->convertType(typeinfoArr.getType()); auto loc = parentOp->getLoc(); mlir::Value result = rewriter.create(loc, llvmTy); for (auto [idx, elt] : llvm::enumerate(typeinfoArr.getData())) { - mlir::Value init = lowerCirAttrAsValue(parentOp, elt, rewriter, converter); + mlir::Value init = + lowerCirAttrAsValue(parentOp, elt, rewriter, converter, dataLayout); result = rewriter.create(loc, result, init, idx); } @@ -491,7 +566,8 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::TypeInfoAttr typeinfoArr, static mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstArrayAttr constArr, mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter) { + const mlir::TypeConverter *converter, + mlir::DataLayout const &dataLayout) { auto llvmTy = converter->convertType(constArr.getType()); auto loc = parentOp->getLoc(); mlir::Value result; @@ -508,7 +584,7 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstArrayAttr constArr, if (auto arrayAttr = mlir::dyn_cast(constArr.getElts())) { for (auto [idx, elt] : llvm::enumerate(arrayAttr)) { mlir::Value init = - lowerCirAttrAsValue(parentOp, elt, rewriter, converter); + emitCirAttrToMemory(parentOp, elt, rewriter, converter, dataLayout); result = rewriter.create(loc, result, init, idx); } @@ -565,7 +641,8 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstVectorAttr constVec, static mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, cir::GlobalViewAttr globalAttr, mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter) { + const mlir::TypeConverter *converter, + mlir::DataLayout const &dataLayout) { auto module = parentOp->getParentOfType(); mlir::Type sourceType; unsigned sourceAddrSpace = 0; @@ -577,7 +654,8 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::GlobalViewAttr globalAttr, symName = llvmSymbol.getSymName(); sourceAddrSpace = llvmSymbol.getAddrSpace(); } else if (auto cirSymbol = dyn_cast(sourceSymbol)) { - sourceType = converter->convertType(cirSymbol.getSymType()); + sourceType = + convertTypeForMemory(*converter, dataLayout, cirSymbol.getSymType()); symName = cirSymbol.getSymName(); sourceAddrSpace = getGlobalOpTargetAddrSpace(rewriter, converter, cirSymbol); @@ -622,7 +700,8 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::GlobalViewAttr globalAttr, auto ptrTy = mlir::dyn_cast(globalAttr.getType()); assert(ptrTy && "Expecting pointer type in GlobalViewAttr"); - auto llvmEltTy = converter->convertType(ptrTy.getPointee()); + auto llvmEltTy = + convertTypeForMemory(*converter, dataLayout, ptrTy.getPointee()); if (llvmEltTy == sourceType) return addrOp; @@ -635,7 +714,8 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::GlobalViewAttr globalAttr, /// Switches on the type of attribute and calls the appropriate conversion. mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr, mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter) { + const mlir::TypeConverter *converter, + mlir::DataLayout const &dataLayout) { if (const auto intAttr = mlir::dyn_cast(attr)) return lowerCirAttrAsValue(parentOp, intAttr, rewriter, converter); if (const auto fltAttr = mlir::dyn_cast(attr)) @@ -643,9 +723,11 @@ mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr, if (const auto ptrAttr = mlir::dyn_cast(attr)) return lowerCirAttrAsValue(parentOp, ptrAttr, rewriter, converter); if (const auto constStruct = mlir::dyn_cast(attr)) - return lowerCirAttrAsValue(parentOp, constStruct, rewriter, converter); + return lowerCirAttrAsValue(parentOp, constStruct, rewriter, converter, + dataLayout); if (const auto constArr = mlir::dyn_cast(attr)) - return lowerCirAttrAsValue(parentOp, constArr, rewriter, converter); + return lowerCirAttrAsValue(parentOp, constArr, rewriter, converter, + dataLayout); if (const auto constVec = mlir::dyn_cast(attr)) return lowerCirAttrAsValue(parentOp, constVec, rewriter, converter); if (const auto boolAttr = mlir::dyn_cast(attr)) @@ -657,11 +739,14 @@ mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr, if (const auto poisonAttr = mlir::dyn_cast(attr)) return lowerCirAttrAsValue(parentOp, poisonAttr, rewriter, converter); if (const auto globalAttr = mlir::dyn_cast(attr)) - return lowerCirAttrAsValue(parentOp, globalAttr, rewriter, converter); + return lowerCirAttrAsValue(parentOp, globalAttr, rewriter, converter, + dataLayout); if (const auto vtableAttr = mlir::dyn_cast(attr)) - return lowerCirAttrAsValue(parentOp, vtableAttr, rewriter, converter); + return lowerCirAttrAsValue(parentOp, vtableAttr, rewriter, converter, + dataLayout); if (const auto typeinfoAttr = mlir::dyn_cast(attr)) - return lowerCirAttrAsValue(parentOp, typeinfoAttr, rewriter, converter); + return lowerCirAttrAsValue(parentOp, typeinfoAttr, rewriter, converter, + dataLayout); llvm_unreachable("unhandled attribute type"); } @@ -816,7 +901,8 @@ mlir::LogicalResult CIRToLLVMPtrStrideOpLowering::matchAndRewrite( mlir::ConversionPatternRewriter &rewriter) const { auto *tc = getTypeConverter(); const auto resultTy = tc->convertType(ptrStrideOp.getType()); - auto elementTy = tc->convertType(ptrStrideOp.getElementTy()); + auto elementTy = + convertTypeForMemory(*tc, dataLayout, ptrStrideOp.getElementTy()); auto *ctx = elementTy.getContext(); // void and function types doesn't really have a layout to use in GEPs, @@ -1012,8 +1098,7 @@ mlir::LogicalResult CIRToLLVMBrCondOpLowering::matchAndRewrite( } if (!i1Condition) - i1Condition = rewriter.create( - brOp.getLoc(), rewriter.getI1Type(), adaptor.getCond()); + i1Condition = adaptor.getCond(); rewriter.replaceOpWithNewOp( brOp, i1Condition, brOp.getDestTrue(), adaptor.getDestOperandsTrue(), @@ -1040,7 +1125,8 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite( const auto ptrTy = mlir::cast(castOp.getType()); auto sourceValue = adaptor.getOperands().front(); auto targetType = convertTy(ptrTy); - auto elementTy = convertTy(ptrTy.getPointee()); + auto elementTy = convertTypeForMemory(*getTypeConverter(), dataLayout, + ptrTy.getPointee()); auto offset = llvm::SmallVector{0}; rewriter.replaceOpWithNewOp( castOp, targetType, elementTy, sourceValue, offset); @@ -1111,9 +1197,7 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite( return mlir::success(); } case cir::CastKind::float_to_bool: { - auto dstTy = mlir::cast(castOp.getType()); auto llvmSrcVal = adaptor.getOperands().front(); - auto llvmDstTy = getTypeConverter()->convertType(dstTy); auto kind = mlir::LLVM::FCmpPredicate::une; // Check if float is not equal to zero. @@ -1122,10 +1206,9 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite( mlir::FloatAttr::get(llvmSrcVal.getType(), 0.0)); // Extend comparison result to either bool (C++) or int (C). - mlir::Value cmpResult = rewriter.create( - castOp.getLoc(), kind, llvmSrcVal, zeroFloat); - rewriter.replaceOpWithNewOp(castOp, llvmDstTy, - cmpResult); + rewriter.replaceOpWithNewOp(castOp, kind, llvmSrcVal, + zeroFloat); + return mlir::success(); } case cir::CastKind::bool_to_int: { @@ -1434,7 +1517,8 @@ mlir::LogicalResult CIRToLLVMAllocaOpLowering::matchAndRewrite( op.getLoc(), typeConverter->convertType(rewriter.getIndexType()), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); - auto elementTy = getTypeConverter()->convertType(op.getAllocaType()); + auto elementTy = + convertTypeForMemory(*getTypeConverter(), dataLayout, op.getAllocaType()); auto resultTy = getTypeConverter()->convertType(op.getResult().getType()); // Verification between the CIR alloca AS and the one from data layout. { @@ -1489,7 +1573,8 @@ getLLVMMemOrder(std::optional &memorder) { mlir::LogicalResult CIRToLLVMLoadOpLowering::matchAndRewrite( cir::LoadOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { - const auto llvmTy = getTypeConverter()->convertType(op.getResult().getType()); + const auto llvmTy = convertTypeForMemory(*getTypeConverter(), dataLayout, + op.getResult().getType()); auto memorder = op.getMemOrder(); auto ordering = getLLVMMemOrder(memorder); auto alignOpt = op.getAlignment(); @@ -1512,10 +1597,15 @@ mlir::LogicalResult CIRToLLVMLoadOpLowering::matchAndRewrite( } // TODO: nontemporal, syncscope. - rewriter.replaceOpWithNewOp( - op, llvmTy, adaptor.getAddr(), /* alignment */ alignment, + auto newLoad = rewriter.create( + op->getLoc(), llvmTy, adaptor.getAddr(), /* alignment */ alignment, op.getIsVolatile(), /* nontemporal */ false, /* invariant */ false, /* invariantGroup */ invariant, ordering); + + // Convert adapted result to its original type if needed. + mlir::Value result = + emitFromMemory(rewriter, dataLayout, op, newLoad.getResult()); + rewriter.replaceOp(op, result); return mlir::LogicalResult::success(); } @@ -1546,9 +1636,12 @@ mlir::LogicalResult CIRToLLVMStoreOpLowering::matchAndRewrite( invariant = addrAllocaOp && addrAllocaOp.getConstant(); } + // Convert adapted value to its memory type if needed. + mlir::Value value = emitToMemory(rewriter, dataLayout, + op.getValue().getType(), adaptor.getValue()); // TODO: nontemporal, syncscope. rewriter.replaceOpWithNewOp( - op, adaptor.getValue(), adaptor.getAddr(), alignment, op.getIsVolatile(), + op, value, adaptor.getAddr(), alignment, op.getIsVolatile(), /* nontemporal */ false, /* invariantGroup */ invariant, ordering); return mlir::LogicalResult::success(); } @@ -1569,9 +1662,9 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite( // Regardless of the type, we should lower the constant of poison value // into PoisonOp. - if (mlir::isa(attr)) { + if (auto poisonAttr = mlir::dyn_cast(attr)) { rewriter.replaceOp( - op, lowerCirAttrAsValue(op, attr, rewriter, getTypeConverter())); + op, lowerCirAttrAsValue(op, poisonAttr, rewriter, getTypeConverter())); return mlir::success(); } @@ -1629,7 +1722,8 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite( } // Lower GlobalViewAttr to llvm.mlir.addressof if (auto gv = mlir::dyn_cast(op.getValue())) { - auto newOp = lowerCirAttrAsValue(op, gv, rewriter, getTypeConverter()); + auto newOp = + lowerCirAttrAsValue(op, gv, rewriter, getTypeConverter(), dataLayout); rewriter.replaceOp(op, newOp); return mlir::success(); } @@ -1655,16 +1749,16 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite( std::optional denseAttr; if (constArr && hasTrailingZeros(constArr)) { - auto newOp = - lowerCirAttrAsValue(op, constArr, rewriter, getTypeConverter()); + auto newOp = lowerCirAttrAsValue(op, constArr, rewriter, + getTypeConverter(), dataLayout); rewriter.replaceOp(op, newOp); return mlir::success(); } else if (constArr && (denseAttr = lowerConstArrayAttr(constArr, typeConverter))) { attr = denseAttr.value(); } else { - auto initVal = - lowerCirAttrAsValue(op, op.getValue(), rewriter, typeConverter); + auto initVal = lowerCirAttrAsValue(op, op.getValue(), rewriter, + typeConverter, dataLayout); rewriter.replaceAllUsesWith(op, initVal); rewriter.eraseOp(op); return mlir::success(); @@ -1675,14 +1769,16 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite( // initializer would be a global constant that is memcopied. Here we just // define a local constant with llvm.undef that will be stored into the // stack. - auto initVal = lowerCirAttrAsValue(op, structAttr, rewriter, typeConverter); + auto initVal = lowerCirAttrAsValue(op, structAttr, rewriter, typeConverter, + dataLayout); rewriter.replaceAllUsesWith(op, initVal); rewriter.eraseOp(op); return mlir::success(); } else if (auto strTy = mlir::dyn_cast(op.getType())) { auto attr = op.getValue(); if (mlir::isa(attr)) { - auto initVal = lowerCirAttrAsValue(op, attr, rewriter, typeConverter); + auto initVal = + lowerCirAttrAsValue(op, attr, rewriter, typeConverter, dataLayout); rewriter.replaceAllUsesWith(op, initVal); rewriter.eraseOp(op); return mlir::success(); @@ -1692,7 +1788,7 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite( << op.getType(); } else if (const auto vecTy = mlir::dyn_cast(op.getType())) { rewriter.replaceOp(op, lowerCirAttrAsValue(op, op.getValue(), rewriter, - getTypeConverter())); + getTypeConverter(), dataLayout)); return mlir::success(); } else return op.emitError() << "unsupported constant type " << op.getType(); @@ -2160,7 +2256,8 @@ mlir::LogicalResult CIRToLLVMSwitchFlatOpLowering::matchAndRewrite( /// insertion point to the end of the initializer block. void CIRToLLVMGlobalOpLowering::setupRegionInitializedLLVMGlobalOp( cir::GlobalOp op, mlir::ConversionPatternRewriter &rewriter) const { - const auto llvmType = getTypeConverter()->convertType(op.getSymType()); + const auto llvmType = + convertTypeForMemory(*getTypeConverter(), dataLayout, op.getSymType()); SmallVector attributes; auto newGlobalOp = rewriter.replaceOpWithNewOp( op, llvmType, op.getConstant(), convertLinkage(op.getLinkage()), @@ -2178,7 +2275,10 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite( mlir::ConversionPatternRewriter &rewriter) const { // Fetch required values to create LLVM op. - const auto llvmType = getTypeConverter()->convertType(op.getSymType()); + const auto CIRSymType = op.getSymType(); + + const auto llvmType = + convertTypeForMemory(*getTypeConverter(), dataLayout, CIRSymType); const auto isConst = op.getConstant(); const auto isDsoLocal = op.getDsolocal(); const auto linkage = convertLinkage(op.getLinkage()); @@ -2222,8 +2322,8 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite( if (!(init = lowerConstArrayAttr(constArr, getTypeConverter()))) { setupRegionInitializedLLVMGlobalOp(op, rewriter); rewriter.create( - op->getLoc(), - lowerCirAttrAsValue(op, constArr, rewriter, typeConverter)); + op->getLoc(), lowerCirAttrAsValue(op, constArr, rewriter, + typeConverter, dataLayout)); return mlir::success(); } } else { @@ -2247,7 +2347,8 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite( // should be updated. For now, we use a custom op to initialize globals // to the appropriate value. setupRegionInitializedLLVMGlobalOp(op, rewriter); - auto value = lowerCirAttrAsValue(op, init.value(), rewriter, typeConverter); + auto value = lowerCirAttrAsValue(op, init.value(), rewriter, typeConverter, + dataLayout); rewriter.create(loc, value); return mlir::success(); } else if (auto dataMemberAttr = @@ -2265,27 +2366,28 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite( mlir::dyn_cast(init.value())) { setupRegionInitializedLLVMGlobalOp(op, rewriter); rewriter.create( - op->getLoc(), - lowerCirAttrAsValue(op, structAttr, rewriter, typeConverter)); + op->getLoc(), lowerCirAttrAsValue(op, structAttr, rewriter, + typeConverter, dataLayout)); return mlir::success(); } else if (auto attr = mlir::dyn_cast(init.value())) { setupRegionInitializedLLVMGlobalOp(op, rewriter); rewriter.create( - loc, lowerCirAttrAsValue(op, attr, rewriter, typeConverter)); + loc, + lowerCirAttrAsValue(op, attr, rewriter, typeConverter, dataLayout)); return mlir::success(); } else if (const auto vtableAttr = mlir::dyn_cast(init.value())) { setupRegionInitializedLLVMGlobalOp(op, rewriter); rewriter.create( - op->getLoc(), - lowerCirAttrAsValue(op, vtableAttr, rewriter, typeConverter)); + op->getLoc(), lowerCirAttrAsValue(op, vtableAttr, rewriter, + typeConverter, dataLayout)); return mlir::success(); } else if (const auto typeinfoAttr = mlir::dyn_cast(init.value())) { setupRegionInitializedLLVMGlobalOp(op, rewriter); rewriter.create( - op->getLoc(), - lowerCirAttrAsValue(op, typeinfoAttr, rewriter, typeConverter)); + op->getLoc(), lowerCirAttrAsValue(op, typeinfoAttr, rewriter, + typeConverter, dataLayout)); return mlir::success(); } else { op.emitError() << "unsupported initializer '" << init.value() << "'"; @@ -2748,7 +2850,6 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite( cir::CmpOp cmpOp, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { auto type = cmpOp.getLhs().getType(); - mlir::Value llResult; // Lower to LLVM comparison op. // if (auto intTy = mlir::dyn_cast(type)) { @@ -2757,27 +2858,21 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite( ? mlir::cast(type).isSigned() : mlir::cast(type).isSigned(); auto kind = convertCmpKindToICmpPredicate(cmpOp.getKind(), isSigned); - llResult = rewriter.create( - cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs()); + rewriter.replaceOpWithNewOp( + cmpOp, kind, adaptor.getLhs(), adaptor.getRhs()); } else if (auto ptrTy = mlir::dyn_cast(type)) { auto kind = convertCmpKindToICmpPredicate(cmpOp.getKind(), /* isSigned=*/false); - llResult = rewriter.create( - cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs()); + rewriter.replaceOpWithNewOp( + cmpOp, kind, adaptor.getLhs(), adaptor.getRhs()); } else if (mlir::isa(type)) { auto kind = convertCmpKindToFCmpPredicate(cmpOp.getKind()); - llResult = rewriter.create( - cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs()); + rewriter.replaceOpWithNewOp( + cmpOp, kind, adaptor.getLhs(), adaptor.getRhs()); } else { return cmpOp.emitError() << "unsupported type for CmpOp: " << type; } - // LLVM comparison ops return i1, but cir::CmpOp returns the same type as - // the LHS value. Since this return value can be used later, we need to - // restore the type with the extension below. - auto llResultTy = getTypeConverter()->convertType(cmpOp.getType()); - rewriter.replaceOpWithNewOp(cmpOp, llResultTy, llResult); - return mlir::success(); } @@ -2827,8 +2922,7 @@ mlir::LogicalResult CIRToLLVMLLVMIntrinsicCallOpLowering::matchAndRewrite( mlir::LogicalResult CIRToLLVMAssumeOpLowering::matchAndRewrite( cir::AssumeOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { - auto cond = rewriter.create( - op.getLoc(), rewriter.getI1Type(), adaptor.getPredicate()); + auto cond = adaptor.getPredicate(); rewriter.replaceOpWithNewOp(op, cond); return mlir::success(); } @@ -3063,9 +3157,7 @@ mlir::LogicalResult CIRToLLVMAtomicCmpXchgLowering::matchAndRewrite( auto cmp = rewriter.create( op.getLoc(), cmpxchg.getResult(), 1); - auto extCmp = rewriter.create(op.getLoc(), - rewriter.getI8Type(), cmp); - rewriter.replaceOp(op, {old, extCmp}); + rewriter.replaceOp(op, {old, cmp}); return mlir::success(); } @@ -3282,9 +3374,7 @@ mlir::LogicalResult CIRToLLVMSelectOpLowering::matchAndRewrite( } } - auto llvmCondition = rewriter.create( - op.getLoc(), mlir::IntegerType::get(op->getContext(), 1), - adaptor.getCondition()); + auto llvmCondition = adaptor.getCondition(); rewriter.replaceOpWithNewOp( op, llvmCondition, adaptor.getTrueValue(), adaptor.getFalseValue()); @@ -3497,8 +3587,8 @@ mlir::LogicalResult CIRToLLVMInlineAsmOpLowering::matchAndRewrite( std::vector attrs; auto typ = cast(cirOperands[i].getType()); - auto typAttr = - mlir::TypeAttr::get(getTypeConverter()->convertType(typ.getPointee())); + auto typAttr = mlir::TypeAttr::get(convertTypeForMemory( + *getTypeConverter(), dataLayout, typ.getPointee())); attrs.push_back(rewriter.getNamedAttr(llvmAttrName, typAttr)); auto newDict = rewriter.getDictionaryAttr(attrs); @@ -3645,13 +3735,7 @@ mlir::LogicalResult CIRToLLVMGetBitfieldOpLowering::matchAndRewrite( mlir::LogicalResult CIRToLLVMIsConstantOpLowering::matchAndRewrite( cir::IsConstantOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { - // FIXME(cir): llvm.intr.is.constant returns i1 value but the LLVM Lowering - // expects that cir.bool type will be lowered as i8 type. - // So we have to insert zext here. - auto isConstantOP = - rewriter.create(op.getLoc(), adaptor.getVal()); - rewriter.replaceOpWithNewOp(op, rewriter.getI8Type(), - isConstantOP); + rewriter.replaceOpWithNewOp(op, adaptor.getVal()); return mlir::success(); } @@ -3871,17 +3955,7 @@ mlir::LogicalResult CIRToLLVMIsFPClassOpLowering::matchAndRewrite( auto flags = adaptor.getFlags(); auto retTy = rewriter.getI1Type(); - auto loc = op->getLoc(); - - auto intrinsic = - rewriter.create(loc, retTy, src, flags); - // FIMXE: CIR now will convert cir::BoolType to i8 type unconditionally. - // Remove this conversion after we fix - // https://github.com/llvm/clangir/issues/480 - auto converted = rewriter.create( - loc, rewriter.getI8Type(), intrinsic->getResult(0)); - - rewriter.replaceOp(op, converted); + rewriter.replaceOpWithNewOp(op, retTy, src, flags); return mlir::success(); } @@ -3960,17 +4034,28 @@ void populateCIRToLLVMConversionPatterns( patterns.add(converter, dataLayout, stringGlobalsMap, argStringGlobalsMap, argsVarMap, patterns.getContext()); + patterns.add< + // clang-format off + CIRToLLVMLoadOpLowering, + CIRToLLVMStoreOpLowering, + CIRToLLVMGlobalOpLowering, + CIRToLLVMConstantOpLowering + // clang-format on + >(converter, patterns.getContext(), lowerModule, dataLayout); patterns.add< // clang-format off CIRToLLVMBaseDataMemberOpLowering, - CIRToLLVMConstantOpLowering, CIRToLLVMDerivedDataMemberOpLowering, - CIRToLLVMGetRuntimeMemberOpLowering, - CIRToLLVMGlobalOpLowering, - CIRToLLVMLoadOpLowering, - CIRToLLVMStoreOpLowering + CIRToLLVMGetRuntimeMemberOpLowering // clang-format on >(converter, patterns.getContext(), lowerModule); + patterns.add< + // clang-format off + CIRToLLVMPtrStrideOpLowering, + CIRToLLVMCastOpLowering, + CIRToLLVMInlineAsmOpLowering + // clang-format on + >(converter, patterns.getContext(), dataLayout); patterns.add< // clang-format off CIRToLLVMAbsOpLowering, @@ -3994,7 +4079,6 @@ void populateCIRToLLVMConversionPatterns( CIRToLLVMBrOpLowering, CIRToLLVMByteswapOpLowering, CIRToLLVMCallOpLowering, - CIRToLLVMCastOpLowering, CIRToLLVMCatchParamOpLowering, CIRToLLVMClearCacheOpLowering, CIRToLLVMCmpOpLowering, @@ -4015,7 +4099,6 @@ void populateCIRToLLVMConversionPatterns( CIRToLLVMGetBitfieldOpLowering, CIRToLLVMGetGlobalOpLowering, CIRToLLVMGetMemberOpLowering, - CIRToLLVMInlineAsmOpLowering, CIRToLLVMIsConstantOpLowering, CIRToLLVMIsFPClassOpLowering, CIRToLLVMLLVMIntrinsicCallOpLowering, @@ -4029,7 +4112,6 @@ void populateCIRToLLVMConversionPatterns( CIRToLLVMPrefetchOpLowering, CIRToLLVMPtrDiffOpLowering, CIRToLLVMPtrMaskOpLowering, - CIRToLLVMPtrStrideOpLowering, CIRToLLVMResumeOpLowering, CIRToLLVMReturnAddrOpLowering, CIRToLLVMRotateOpLowering, @@ -4107,7 +4189,7 @@ void prepareTypeConverter(mlir::LLVMTypeConverter &converter, return converter.convertType(abiType); }); converter.addConversion([&](cir::ArrayType type) -> mlir::Type { - auto ty = converter.convertType(type.getEltType()); + auto ty = convertTypeForMemory(converter, dataLayout, type.getEltType()); return mlir::LLVM::LLVMArrayType::get(ty, type.getSize()); }); converter.addConversion([&](cir::VectorType type) -> mlir::Type { @@ -4115,7 +4197,7 @@ void prepareTypeConverter(mlir::LLVMTypeConverter &converter, return mlir::LLVM::getFixedVectorType(ty, type.getSize()); }); converter.addConversion([&](cir::BoolType type) -> mlir::Type { - return mlir::IntegerType::get(type.getContext(), 8, + return mlir::IntegerType::get(type.getContext(), 1, mlir::IntegerType::Signless); }); converter.addConversion([&](cir::IntType type) -> mlir::Type { @@ -4168,13 +4250,14 @@ void prepareTypeConverter(mlir::LLVMTypeConverter &converter, // TODO(cir): This should be properly validated. case cir::StructType::Struct: for (auto ty : type.getMembers()) - llvmMembers.push_back(converter.convertType(ty)); + llvmMembers.push_back(convertTypeForMemory(converter, dataLayout, ty)); break; // Unions are lowered as only the largest member. case cir::StructType::Union: { auto largestMember = type.getLargestMember(dataLayout); if (largestMember) - llvmMembers.push_back(converter.convertType(largestMember)); + llvmMembers.push_back( + convertTypeForMemory(converter, dataLayout, largestMember)); break; } } diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h index 48baae2ae799..12ded1f39c80 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h @@ -14,13 +14,18 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Transforms/DialectConversion.h" namespace cir { namespace direct { + +/// Convert a CIR attribute to an LLVM attribute. May use the datalayout for +/// lowering attributes to-be-stored in memory. mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr, mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter); + const mlir::TypeConverter *converter, + mlir::DataLayout const &dataLayout); mlir::LLVM::Linkage convertLinkage(cir::GlobalLinkageKind linkage); @@ -137,7 +142,13 @@ class CIRToLLVMMemSetInlineOpLowering class CIRToLLVMPtrStrideOpLowering : public mlir::OpConversionPattern { + mlir::DataLayout const &dataLayout; + public: + CIRToLLVMPtrStrideOpLowering(const mlir::TypeConverter &typeConverter, + mlir::MLIRContext *context, + mlir::DataLayout const &dataLayout) + : OpConversionPattern(typeConverter, context), dataLayout(dataLayout) {} using mlir::OpConversionPattern::OpConversionPattern; mlir::LogicalResult @@ -216,9 +227,15 @@ class CIRToLLVMBrCondOpLowering }; class CIRToLLVMCastOpLowering : public mlir::OpConversionPattern { + mlir::DataLayout const &dataLayout; + mlir::Type convertTy(mlir::Type ty) const; public: + CIRToLLVMCastOpLowering(const mlir::TypeConverter &typeConverter, + mlir::MLIRContext *context, + mlir::DataLayout const &dataLayout) + : OpConversionPattern(typeConverter, context), dataLayout(dataLayout) {} using mlir::OpConversionPattern::OpConversionPattern; mlir::LogicalResult @@ -302,12 +319,15 @@ class CIRToLLVMAllocaOpLowering class CIRToLLVMLoadOpLowering : public mlir::OpConversionPattern { cir::LowerModule *lowerMod; + mlir::DataLayout const &dataLayout; public: CIRToLLVMLoadOpLowering(const mlir::TypeConverter &typeConverter, mlir::MLIRContext *context, - cir::LowerModule *lowerModule) - : OpConversionPattern(typeConverter, context), lowerMod(lowerModule) {} + cir::LowerModule *lowerModule, + mlir::DataLayout const &dataLayout) + : OpConversionPattern(typeConverter, context), lowerMod(lowerModule), + dataLayout(dataLayout) {} mlir::LogicalResult matchAndRewrite(cir::LoadOp op, OpAdaptor, @@ -317,12 +337,15 @@ class CIRToLLVMLoadOpLowering : public mlir::OpConversionPattern { class CIRToLLVMStoreOpLowering : public mlir::OpConversionPattern { cir::LowerModule *lowerMod; + mlir::DataLayout const &dataLayout; public: CIRToLLVMStoreOpLowering(const mlir::TypeConverter &typeConverter, mlir::MLIRContext *context, - cir::LowerModule *lowerModule) - : OpConversionPattern(typeConverter, context), lowerMod(lowerModule) {} + cir::LowerModule *lowerModule, + mlir::DataLayout const &dataLayout) + : OpConversionPattern(typeConverter, context), lowerMod(lowerModule), + dataLayout(dataLayout) {} mlir::LogicalResult matchAndRewrite(cir::StoreOp op, OpAdaptor, @@ -332,12 +355,15 @@ class CIRToLLVMStoreOpLowering class CIRToLLVMConstantOpLowering : public mlir::OpConversionPattern { cir::LowerModule *lowerMod; + mlir::DataLayout const &dataLayout; public: CIRToLLVMConstantOpLowering(const mlir::TypeConverter &typeConverter, mlir::MLIRContext *context, - cir::LowerModule *lowerModule) - : OpConversionPattern(typeConverter, context), lowerMod(lowerModule) { + cir::LowerModule *lowerModule, + mlir::DataLayout const &dataLayout) + : OpConversionPattern(typeConverter, context), lowerMod(lowerModule), + dataLayout(dataLayout) { setHasBoundedRewriteRecursion(); } @@ -538,12 +564,15 @@ class CIRToLLVMSwitchFlatOpLowering class CIRToLLVMGlobalOpLowering : public mlir::OpConversionPattern { cir::LowerModule *lowerMod; + mlir::DataLayout const &dataLayout; public: CIRToLLVMGlobalOpLowering(const mlir::TypeConverter &typeConverter, mlir::MLIRContext *context, - cir::LowerModule *lowerModule) - : OpConversionPattern(typeConverter, context), lowerMod(lowerModule) { + cir::LowerModule *lowerModule, + mlir::DataLayout const &dataLayout) + : OpConversionPattern(typeConverter, context), lowerMod(lowerModule), + dataLayout(dataLayout) { setHasBoundedRewriteRecursion(); } @@ -904,7 +933,14 @@ class CIRToLLVMTrapOpLowering : public mlir::OpConversionPattern { class CIRToLLVMInlineAsmOpLowering : public mlir::OpConversionPattern { + mlir::DataLayout const &dataLayout; + public: + CIRToLLVMInlineAsmOpLowering(const mlir::TypeConverter &typeConverter, + mlir::MLIRContext *context, + mlir::DataLayout const &dataLayout) + : OpConversionPattern(typeConverter, context), dataLayout(dataLayout) {} + using mlir::OpConversionPattern::OpConversionPattern; mlir::LogicalResult diff --git a/clang/test/CIR/CodeGen/atomic-xchg-field.c b/clang/test/CIR/CodeGen/atomic-xchg-field.c index c01abf7bae6e..fd9267632344 100644 --- a/clang/test/CIR/CodeGen/atomic-xchg-field.c +++ b/clang/test/CIR/CodeGen/atomic-xchg-field.c @@ -58,16 +58,14 @@ void structAtomicExchange(unsigned referenceCount, wPtr item) { // LLVM: %[[RES:.*]] = cmpxchg weak ptr %9, i32 %[[EXP]], i32 %[[DES]] seq_cst seq_cst // LLVM: %[[OLD:.*]] = extractvalue { i32, i1 } %[[RES]], 0 // LLVM: %[[CMP:.*]] = extractvalue { i32, i1 } %[[RES]], 1 -// LLVM: %[[Z:.*]] = zext i1 %[[CMP]] to i8 -// LLVM: %[[X:.*]] = xor i8 %[[Z]], 1 -// LLVM: %[[FAIL:.*]] = trunc i8 %[[X]] to i1 - -// LLVM: br i1 %[[FAIL:.*]], label %[[STORE_OLD:.*]], label %[[CONTINUE:.*]] +// LLVM: %[[FAIL:.*]] = xor i1 %[[CMP]], true +// LLVM: br i1 %[[FAIL]], label %[[STORE_OLD:.*]], label %[[CONTINUE:.*]] // LLVM: [[STORE_OLD]]: // LLVM: store i32 %[[OLD]], ptr // LLVM: br label %[[CONTINUE]] // LLVM: [[CONTINUE]]: +// LLVM: %[[Z:.*]] = zext i1 %[[CMP]] to i8 // LLVM: store i8 %[[Z]], ptr {{.*}}, align 1 // LLVM: ret void diff --git a/clang/test/CIR/CodeGen/bf16-ops.c b/clang/test/CIR/CodeGen/bf16-ops.c index 406446b778eb..d0c051a8d5e5 100644 --- a/clang/test/CIR/CodeGen/bf16-ops.c +++ b/clang/test/CIR/CodeGen/bf16-ops.c @@ -41,14 +41,12 @@ void foo(void) { // NATIVE-NEXT: %{{.+}} = cir.cast(integral, %[[#C]] : !s32i), !u32i // NONATIVE-LLVM: %[[#A:]] = fcmp une bfloat %{{.+}}, 0xR0000 - // NONATIVE-LLVM-NEXT: %[[#B:]] = zext i1 %[[#A]] to i8 - // NONATIVE-LLVM-NEXT: %[[#C:]] = xor i8 %[[#B]], 1 - // NONATIVE-LLVM-NEXT: %{{.+}} = zext i8 %[[#C]] to i32 + // NONATIVE-LLVM-NEXT: %[[#C:]] = xor i1 %[[#A]], true + // NONATIVE-LLVM-NEXT: %{{.+}} = zext i1 %[[#C]] to i32 // NATIVE-LLVM: %[[#A:]] = fcmp une bfloat %{{.+}}, 0xR0000 - // NATIVE-LLVM-NEXT: %[[#B:]] = zext i1 %[[#A]] to i8 - // NATIVE-LLVM-NEXT: %[[#C:]] = xor i8 %[[#B]], 1 - // NATIVE-LLVM-NEXT: %{{.+}} = zext i8 %[[#C]] to i32 + // NATIVE-LLVM-NEXT: %[[#C:]] = xor i1 %[[#A]], true + // NATIVE-LLVM-NEXT: %{{.+}} = zext i1 %[[#C]] to i32 h1 = -h1; // NONATIVE: %[[#A:]] = cir.cast(floating, %{{.+}} : !cir.bf16), !cir.float diff --git a/clang/test/CIR/CodeGen/builtin-assume.cpp b/clang/test/CIR/CodeGen/builtin-assume.cpp index 8d7448a2724d..9a099c0c94f9 100644 --- a/clang/test/CIR/CodeGen/builtin-assume.cpp +++ b/clang/test/CIR/CodeGen/builtin-assume.cpp @@ -16,7 +16,7 @@ int test_assume(int x) { // CIR: } // LLVM: @_Z11test_assumei -// LLVM: %[[#cond:]] = trunc i8 %{{.+}} to i1 +// LLVM: %[[#cond:]] = icmp sgt i32 %{{.+}}, 0 // LLVM-NEXT: call void @llvm.assume(i1 %[[#cond]]) int test_assume_attr(int x) { @@ -32,7 +32,7 @@ int test_assume_attr(int x) { // CIR: } // LLVM: @_Z16test_assume_attri -// LLVM: %[[#cond:]] = trunc i8 %{{.+}} to i1 +// LLVM: %[[#cond:]] = icmp sgt i32 %{{.+}}, 0 // LLVM-NEXT: call void @llvm.assume(i1 %[[#cond]]) int test_assume_aligned(int *ptr) { diff --git a/clang/test/CIR/CodeGen/builtin-constant-p.c b/clang/test/CIR/CodeGen/builtin-constant-p.c index a8eb13adacfd..810806ec2443 100644 --- a/clang/test/CIR/CodeGen/builtin-constant-p.c +++ b/clang/test/CIR/CodeGen/builtin-constant-p.c @@ -20,8 +20,7 @@ int foo() { // LLVM: [[TMP1:%.*]] = alloca i32, i64 1 // LLVM: [[TMP2:%.*]] = load i32, ptr @a // LLVM: [[TMP3:%.*]] = call i1 @llvm.is.constant.i32(i32 [[TMP2]]) -// LLVM: [[TMP4:%.*]] = zext i1 [[TMP3]] to i8 -// LLVM: [[TMP5:%.*]] = zext i8 [[TMP4]] to i32 +// LLVM: [[TMP5:%.*]] = zext i1 [[TMP3]] to i32 // LLVM: store i32 [[TMP5]], ptr [[TMP1]] // LLVM: [[TMP6:%.*]] = load i32, ptr [[TMP1]] // LLVM: ret i32 [[TMP6]] diff --git a/clang/test/CIR/CodeGen/complex-arithmetic.c b/clang/test/CIR/CodeGen/complex-arithmetic.c index eddedc2d3a27..3630edfc6033 100644 --- a/clang/test/CIR/CodeGen/complex-arithmetic.c +++ b/clang/test/CIR/CodeGen/complex-arithmetic.c @@ -303,12 +303,9 @@ void mul() { // LLVM-FULL-NEXT: %[[#F:]] = fadd double %[[#C]], %[[#D]] // LLVM-FULL-NEXT: %[[#G:]] = insertvalue { double, double } undef, double %[[#E]], 0 // LLVM-FULL-NEXT: %[[#RES:]] = insertvalue { double, double } %[[#G]], double %[[#F]], 1 -// LLVM-FULL-NEXT: %[[#H:]] = fcmp une double %[[#E]], %[[#E]] -// LLVM-FULL-NEXT: %[[#COND:]] = zext i1 %[[#H]] to i8 -// LLVM-FULL-NEXT: %[[#I:]] = fcmp une double %[[#F]], %[[#F]] -// LLVM-FULL-NEXT: %[[#COND2:]] = zext i1 %[[#I]] to i8 -// LLVM-FULL-NEXT: %[[#J:]] = and i8 %[[#COND]], %[[#COND2]] -// LLVM-FULL-NEXT: %[[#COND3:]] = trunc i8 %[[#J]] to i1 +// LLVM-FULL-NEXT: %[[#COND:]] = fcmp une double %[[#E]], %[[#E]] +// LLVM-FULL-NEXT: %[[#COND2:]] = fcmp une double %[[#F]], %[[#F]] +// LLVM-FULL-NEXT: %[[#COND3:]] = and i1 %[[#COND]], %[[#COND2]] // LLVM-FULL: {{.+}}: // LLVM-FULL-NEXT: %{{.+}} = call { double, double } @__muldc3(double %[[#LHSR]], double %[[#LHSI]], double %[[#RHSR]], double %[[#RHSI]]) // LLVM-FULL-NEXT: br label %{{.+}} diff --git a/clang/test/CIR/CodeGen/complex-cast.c b/clang/test/CIR/CodeGen/complex-cast.c index 98afabd65340..5cadcf711a60 100644 --- a/clang/test/CIR/CodeGen/complex-cast.c +++ b/clang/test/CIR/CodeGen/complex-cast.c @@ -179,10 +179,8 @@ void complex_to_bool() { // LLVM: %[[#REAL:]] = extractvalue { double, double } %{{.+}}, 0 // LLVM-NEXT: %[[#IMAG:]] = extractvalue { double, double } %{{.+}}, 1 // LLVM-NEXT: %[[#RB:]] = fcmp une double %[[#REAL]], 0.000000e+00 -// LLVM-NEXT: %[[#RB2:]] = zext i1 %[[#RB]] to i8 // LLVM-NEXT: %[[#IB:]] = fcmp une double %[[#IMAG]], 0.000000e+00 -// LLVM-NEXT: %[[#IB2:]] = zext i1 %[[#IB]] to i8 -// LLVM-NEXT: %{{.+}} = or i8 %[[#RB2]], %[[#IB2]] +// LLVM-NEXT: %{{.+}} = or i1 %[[#RB]], %[[#IB]] // CIR-BEFORE: %{{.+}} = cir.cast(int_complex_to_bool, %{{.+}} : !cir.complex), !cir.bool @@ -196,10 +194,8 @@ void complex_to_bool() { // LLVM: %[[#REAL:]] = extractvalue { i32, i32 } %{{.+}}, 0 // LLVM-NEXT: %[[#IMAG:]] = extractvalue { i32, i32 } %{{.+}}, 1 // LLVM-NEXT: %[[#RB:]] = icmp ne i32 %[[#REAL]], 0 -// LLVM-NEXT: %[[#RB2:]] = zext i1 %[[#RB]] to i8 // LLVM-NEXT: %[[#IB:]] = icmp ne i32 %[[#IMAG]], 0 -// LLVM-NEXT: %[[#IB2:]] = zext i1 %[[#IB]] to i8 -// LLVM-NEXT: %{{.+}} = or i8 %[[#RB2]], %[[#IB2]] +// LLVM-NEXT: %{{.+}} = or i1 %[[#RB]], %[[#IB]] // CHECK: } diff --git a/clang/test/CIR/CodeGen/new-null.cpp b/clang/test/CIR/CodeGen/new-null.cpp index 1957d54873a0..4f46cbd51147 100644 --- a/clang/test/CIR/CodeGen/new-null.cpp +++ b/clang/test/CIR/CodeGen/new-null.cpp @@ -66,7 +66,7 @@ namespace test15 { // LLVM: %[[VAL_0:.*]] = alloca ptr, i64 1, align 8 // LLVM: store ptr %[[VAL_1:.*]], ptr %[[VAL_0]], align 8 // LLVM: %[[VAL_2:.*]] = load ptr, ptr %[[VAL_0]], align 8 - // LLVM: %[[VAL_3:.*]] = call ptr @_ZnwmPvb(i64 1, ptr %[[VAL_2]], i8 1) + // LLVM: %[[VAL_3:.*]] = call ptr @_ZnwmPvb(i64 1, ptr %[[VAL_2]], i1 true) // LLVM: %[[VAL_4:.*]] = icmp ne ptr %[[VAL_3]], null // LLVM: br i1 %[[VAL_4]], label %[[VAL_5:.*]], label %[[VAL_6:.*]] // LLVM: [[VAL_5]]: ; preds = %[[VAL_7:.*]] diff --git a/clang/test/CIR/CodeGen/pointer-to-data-member-cast.cpp b/clang/test/CIR/CodeGen/pointer-to-data-member-cast.cpp index 0127559bba65..63625236e42a 100644 --- a/clang/test/CIR/CodeGen/pointer-to-data-member-cast.cpp +++ b/clang/test/CIR/CodeGen/pointer-to-data-member-cast.cpp @@ -23,10 +23,8 @@ auto base_to_derived(int Base2::*ptr) -> int Derived::* { // LLVM: %[[#src:]] = load i64, ptr %{{.+}} // LLVM-NEXT: %[[#is_null:]] = icmp eq i64 %[[#src]], -1 - // LLVM-NEXT: %[[#is_null_bool:]] = zext i1 %[[#is_null]] to i8 // LLVM-NEXT: %[[#adjusted:]] = add i64 %[[#src]], 4 - // LLVM-NEXT: %[[#cond:]] = trunc i8 %[[#is_null_bool]] to i1 - // LLVM-NEXT: %{{.+}} = select i1 %[[#cond]], i64 -1, i64 %[[#adjusted]] + // LLVM-NEXT: %{{.+}} = select i1 %[[#is_null]], i64 -1, i64 %[[#adjusted]] } // CIR-LABEL: @_Z15derived_to_baseM7Derivedi @@ -37,10 +35,8 @@ auto derived_to_base(int Derived::*ptr) -> int Base2::* { // LLVM: %[[#src:]] = load i64, ptr %{{.+}} // LLVM-NEXT: %[[#is_null:]] = icmp eq i64 %[[#src]], -1 - // LLVM-NEXT: %[[#is_null_bool:]] = zext i1 %[[#is_null]] to i8 // LLVM-NEXT: %[[#adjusted:]] = sub i64 %[[#src]], 4 - // LLVM-NEXT: %[[#cond:]] = trunc i8 %[[#is_null_bool]] to i1 - // LLVM-NEXT: %9 = select i1 %[[#cond]], i64 -1, i64 %[[#adjusted]] + // LLVM-NEXT: %{{.+}} = select i1 %[[#is_null]], i64 -1, i64 %[[#adjusted]] } // CIR-LABEL: @_Z27base_to_derived_zero_offsetM5Base1i diff --git a/clang/test/CIR/CodeGen/static.cpp b/clang/test/CIR/CodeGen/static.cpp index 2ba42118dddb..88ff490c14ff 100644 --- a/clang/test/CIR/CodeGen/static.cpp +++ b/clang/test/CIR/CodeGen/static.cpp @@ -77,11 +77,11 @@ static Init __ioinit2(false); // LLVM: @_ZL9__ioinit2 = internal global %class.Init zeroinitializer // LLVM: @llvm.global_ctors = appending constant [2 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 65536, ptr @__cxx_global_var_init, ptr null }, { i32, ptr, ptr } { i32 65536, ptr @__cxx_global_var_init.1, ptr null }] // LLVM: define internal void @__cxx_global_var_init() -// LLVM-NEXT: call void @_ZN4InitC1Eb(ptr @_ZL8__ioinit, i8 1) +// LLVM-NEXT: call void @_ZN4InitC1Eb(ptr @_ZL8__ioinit, i1 true) // LLVM-NEXT: call void @__cxa_atexit(ptr @_ZN4InitD1Ev, ptr @_ZL8__ioinit, ptr @__dso_handle) // LLVM-NEXT: ret void // LLVM: define internal void @__cxx_global_var_init.1() -// LLVM-NEXT: call void @_ZN4InitC1Eb(ptr @_ZL9__ioinit2, i8 0) +// LLVM-NEXT: call void @_ZN4InitC1Eb(ptr @_ZL9__ioinit2, i1 false) // LLVM-NEXT: call void @__cxa_atexit(ptr @_ZN4InitD1Ev, ptr @_ZL9__ioinit2, ptr @__dso_handle) // LLVM-NEXT: ret void // LLVM: define void @_GLOBAL__sub_I_static.cpp() diff --git a/clang/test/CIR/IR/invalid.cir b/clang/test/CIR/IR/invalid.cir index aaeb46e770b5..c628e3c2b46b 100644 --- a/clang/test/CIR/IR/invalid.cir +++ b/clang/test/CIR/IR/invalid.cir @@ -1140,8 +1140,8 @@ module { !s8i = !cir.int cir.func @no_reference_global() { // expected-error @below {{'cir.get_global' op 'str' does not reference a valid cir.global or cir.func}} - %0 = cir.get_global @str : !cir.ptr - cir.return + %0 = cir.get_global @str : !cir.ptr + cir.return } // ----- @@ -1458,7 +1458,7 @@ cir.global external @f = #cir.fp<0x7FC0000007FC0000007FC000000> : !cir.long_doub // ----- -// Long double with `double` semnatics should have a value that fits in a double. +// Long double with `double` semantics should have a value that fits in a double. // CHECK: cir.global external @f = #cir.fp<0x7FC000007FC000000000> : !cir.long_double cir.global external @f = #cir.fp<0x7FC000007FC000000000> : !cir.long_double diff --git a/clang/test/CIR/Lowering/binop-overflow.cir b/clang/test/CIR/Lowering/binop-overflow.cir index 196771150dbe..6a2ef54c1501 100644 --- a/clang/test/CIR/Lowering/binop-overflow.cir +++ b/clang/test/CIR/Lowering/binop-overflow.cir @@ -11,22 +11,20 @@ module { cir.return %overflow : !cir.bool } - // MLIR: llvm.func @test_add_u32_u32_u32(%[[LHS:.+]]: i32, %[[RHS:.+]]: i32, %[[RES_PTR:.+]]: !llvm.ptr) -> i8 + // MLIR: llvm.func @test_add_u32_u32_u32(%[[LHS:.+]]: i32, %[[RHS:.+]]: i32, %[[RES_PTR:.+]]: !llvm.ptr) -> i1 // MLIR-NEXT: %[[#INTRIN_RET:]] = llvm.call_intrinsic "llvm.uadd.with.overflow.i32"(%[[LHS]], %[[RHS]]) : (i32, i32) -> !llvm.struct<(i32, i1)> // MLIR-NEXT: %[[#RES:]] = llvm.extractvalue %[[#INTRIN_RET]][0] : !llvm.struct<(i32, i1)> // MLIR-NEXT: %[[#OVFL:]] = llvm.extractvalue %[[#INTRIN_RET]][1] : !llvm.struct<(i32, i1)> - // MLIR-NEXT: %[[#OVFL_EXT:]] = llvm.zext %[[#OVFL]] : i1 to i8 // MLIR-NEXT: llvm.store %[[#RES]], %[[RES_PTR]] {{.*}} : i32, !llvm.ptr - // MLIR-NEXT: llvm.return %[[#OVFL_EXT]] : i8 + // MLIR-NEXT: llvm.return %[[#OVFL]] : i1 // MLIR-NEXT: } - // LLVM: define i8 @test_add_u32_u32_u32(i32 %[[#LHS:]], i32 %[[#RHS:]], ptr %[[#RES_PTR:]]) + // LLVM: define i1 @test_add_u32_u32_u32(i32 %[[#LHS:]], i32 %[[#RHS:]], ptr %[[#RES_PTR:]]) // LLVM-NEXT: %[[#INTRIN_RET:]] = call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 %[[#LHS]], i32 %[[#RHS]]) // LLVM-NEXT: %[[#RES:]] = extractvalue { i32, i1 } %[[#INTRIN_RET]], 0 // LLVM-NEXT: %[[#OVFL:]] = extractvalue { i32, i1 } %[[#INTRIN_RET]], 1 - // LLVM-NEXT: %[[#OVFL_EXT:]] = zext i1 %[[#OVFL]] to i8 // LLVM-NEXT: store i32 %[[#RES]], ptr %[[#RES_PTR]], align 4 - // LLVM-NEXT: ret i8 %[[#OVFL_EXT]] + // LLVM-NEXT: ret i1 %[[#OVFL]] // LLVM-NEXT: } cir.func @test_add_u32_u32_i32(%lhs: !u32i, %rhs: !u32i, %res: !cir.ptr) -> !cir.bool { @@ -35,7 +33,7 @@ module { cir.return %overflow : !cir.bool } - // MLIR: llvm.func @test_add_u32_u32_i32(%[[LHS:.+]]: i32, %[[RHS:.+]]: i32, %[[RES_PTR:.+]]: !llvm.ptr) -> i8 + // MLIR: llvm.func @test_add_u32_u32_i32(%[[LHS:.+]]: i32, %[[RHS:.+]]: i32, %[[RES_PTR:.+]]: !llvm.ptr) -> i1 // MLIR-NEXT: %[[#LHS_EXT:]] = llvm.zext %[[LHS]] : i32 to i33 // MLIR-NEXT: %[[#RHS_EXT:]] = llvm.zext %[[RHS]] : i32 to i33 // MLIR-NEXT: %[[#INTRIN_RET:]] = llvm.call_intrinsic "llvm.sadd.with.overflow.i33"(%[[#LHS_EXT]], %[[#RHS_EXT]]) : (i33, i33) -> !llvm.struct<(i33, i1)> @@ -45,12 +43,11 @@ module { // MLIR-NEXT: %[[#RES_EXT_2:]] = llvm.sext %[[#RES]] : i32 to i33 // MLIR-NEXT: %[[#TRUNC_OVFL:]] = llvm.icmp "ne" %[[#RES_EXT_2]], %[[#RES_EXT]] : i33 // MLIR-NEXT: %[[#OVFL:]] = llvm.or %[[#ARITH_OVFL]], %[[#TRUNC_OVFL]] : i1 - // MLIR-NEXT: %[[#OVFL_EXT:]] = llvm.zext %[[#OVFL]] : i1 to i8 // MLIR-NEXT: llvm.store %[[#RES]], %[[RES_PTR]] {{.*}} : i32, !llvm.ptr - // MLIR-NEXT: llvm.return %[[#OVFL_EXT]] : i8 + // MLIR-NEXT: llvm.return %[[#OVFL]] : i1 // MLIR-NEXT: } - // LLVM: define i8 @test_add_u32_u32_i32(i32 %[[#LHS:]], i32 %[[#RHS:]], ptr %[[#RES_PTR:]]) + // LLVM: define i1 @test_add_u32_u32_i32(i32 %[[#LHS:]], i32 %[[#RHS:]], ptr %[[#RES_PTR:]]) // LLVM-NEXT: %[[#LHS_EXT:]] = zext i32 %[[#LHS]] to i33 // LLVM-NEXT: %[[#RHS_EXT:]] = zext i32 %[[#RHS]] to i33 // LLVM-NEXT: %[[#INTRIN_RET:]] = call { i33, i1 } @llvm.sadd.with.overflow.i33(i33 %[[#LHS_EXT]], i33 %[[#RHS_EXT]]) @@ -60,8 +57,7 @@ module { // LLVM-NEXT: %[[#RES_EXT_2:]] = sext i32 %[[#RES]] to i33 // LLVM-NEXT: %[[#TRUNC_OVFL:]] = icmp ne i33 %[[#RES_EXT_2]], %[[#RES_EXT]] // LLVM-NEXT: %[[#OVFL:]] = or i1 %[[#ARITH_OVFL]], %[[#TRUNC_OVFL]] - // LLVM-NEXT: %[[#OVFL_EXT:]] = zext i1 %[[#OVFL]] to i8 // LLVM-NEXT: store i32 %[[#RES]], ptr %[[#RES_PTR]], align 4 - // LLVM-NEXT: ret i8 %[[#OVFL_EXT]] + // LLVM-NEXT: ret i1 %[[#OVFL]] // LLVM-NEXT: } } diff --git a/clang/test/CIR/Lowering/bool.cir b/clang/test/CIR/Lowering/bool.cir index 2d3fc2d8590b..848b552f897a 100644 --- a/clang/test/CIR/Lowering/bool.cir +++ b/clang/test/CIR/Lowering/bool.cir @@ -16,10 +16,11 @@ module { cir.return } // MLIR: llvm.func @foo() -// MLIR-DAG: = llvm.mlir.constant(1 : i8) : i8 -// MLIR-DAG: [[Value:%[a-z0-9]+]] = llvm.mlir.constant(1 : index) : i64 -// MLIR-DAG: = llvm.alloca [[Value]] x i8 {alignment = 1 : i64} : (i64) -> !llvm.ptr -// MLIR-DAG: llvm.store %0, %2 {{.*}} : i8, !llvm.ptr +// MLIR-DAG: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1 +// MLIR-DAG: %[[VALUE:.*]] = llvm.mlir.constant(1 : index) : i64 +// MLIR-DAG: %[[ADDR:.*]] = llvm.alloca %[[VALUE]] x i8 {alignment = 1 : i64} : (i64) -> !llvm.ptr +// MLIR-DAG: %[[TRUE_EXT:.*]] = llvm.zext %[[TRUE]] : i1 to i8 +// MLIR-DAG: llvm.store %[[TRUE_EXT]], %[[ADDR]] {{.*}} : i8, !llvm.ptr // MLIR-NEXT: llvm.return // LLVM: define void @foo() diff --git a/clang/test/CIR/Lowering/branch.cir b/clang/test/CIR/Lowering/branch.cir index a99a217f18da..0daea329f4b8 100644 --- a/clang/test/CIR/Lowering/branch.cir +++ b/clang/test/CIR/Lowering/branch.cir @@ -13,25 +13,23 @@ cir.func @foo(%arg0: !cir.bool) -> !s32i { } // MLIR: module { -// MLIR-NEXT: llvm.func @foo(%arg0: i8) -> i32 -// MLIR-NEXT: %0 = llvm.trunc %arg0 : i8 to i1 -// MLIR-NEXT: llvm.cond_br %0, ^bb1, ^bb2 +// MLIR-NEXT: llvm.func @foo(%arg0: i1) -> i32 +// MLIR-NEXT: llvm.cond_br %arg0, ^bb1, ^bb2 // MLIR-NEXT: ^bb1: // pred: ^bb0 -// MLIR-NEXT: %1 = llvm.mlir.constant(1 : i32) : i32 -// MLIR-NEXT: llvm.return %1 : i32 +// MLIR-NEXT: %0 = llvm.mlir.constant(1 : i32) : i32 +// MLIR-NEXT: llvm.return %0 : i32 // MLIR-NEXT: ^bb2: // pred: ^bb0 -// MLIR-NEXT: %2 = llvm.mlir.constant(0 : i32) : i32 -// MLIR-NEXT: llvm.return %2 : i32 +// MLIR-NEXT: %1 = llvm.mlir.constant(0 : i32) : i32 +// MLIR-NEXT: llvm.return %1 : i32 // MLIR-NEXT: } // MLIR-NEXT: } -// LLVM: define i32 @foo(i8 %0) -// LLVM-NEXT: %2 = trunc i8 %0 to i1 -// LLVM-NEXT: br i1 %2, label %3, label %4 +// LLVM: define i32 @foo(i1 %0) +// LLVM-NEXT: br i1 %0, label %2, label %3 // LLVM-EMPTY: -// LLVM-NEXT: 3: ; preds = %1 +// LLVM-NEXT: 2: ; preds = %1 // LLVM-NEXT: ret i32 1 // LLVM-EMPTY: -// LLVM-NEXT: 4: ; preds = %1 +// LLVM-NEXT: 3: ; preds = %1 // LLVM-NEXT: ret i32 0 // LLVM-NEXT: } diff --git a/clang/test/CIR/Lowering/brcond.cir b/clang/test/CIR/Lowering/brcond.cir index 262e0a8f868b..19e778cef823 100644 --- a/clang/test/CIR/Lowering/brcond.cir +++ b/clang/test/CIR/Lowering/brcond.cir @@ -4,40 +4,39 @@ !s32i = !cir.int #fn_attr = #cir, nothrow = #cir.nothrow, optnone = #cir.optnone})> module { cir.func no_proto @test() -> !cir.bool extra(#fn_attr) { - %0 = cir.const #cir.int<0> : !s32i - %1 = cir.cast(int_to_bool, %0 : !s32i), !cir.bool - cir.br ^bb1 + %0 = cir.const #cir.int<0> : !s32i + %1 = cir.cast(int_to_bool, %0 : !s32i), !cir.bool + cir.br ^bb1 ^bb1: - cir.brcond %1 ^bb2, ^bb3 + cir.brcond %1 ^bb2, ^bb3 ^bb2: - cir.return %1 : !cir.bool + cir.return %1 : !cir.bool ^bb3: - cir.br ^bb4 + cir.br ^bb4 ^bb4: - cir.return %1 : !cir.bool - } + cir.return %1 : !cir.bool + } } // MLIR: {{.*}} = llvm.mlir.constant(0 : i32) : i32 // MLIR-NEXT: {{.*}} = llvm.mlir.constant(0 : i32) : i32 // MLIR-NEXT: {{.*}} = llvm.icmp "ne" {{.*}}, {{.*}} : i32 -// MLIR-NEXT: {{.*}} = llvm.zext {{.*}} : i1 to i8 // MLIR-NEXT: llvm.br ^bb1 // MLIR-NEXT: ^bb1: // MLIR-NEXT: llvm.cond_br {{.*}}, ^bb2, ^bb3 // MLIR-NEXT: ^bb2: -// MLIR-NEXT: llvm.return {{.*}} : i8 +// MLIR-NEXT: llvm.return {{.*}} : i1 // MLIR-NEXT: ^bb3: // MLIR-NEXT: llvm.br ^bb4 // MLIR-NEXT: ^bb4: -// MLIR-NEXT: llvm.return {{.*}} : i8 +// MLIR-NEXT: llvm.return {{.*}} : i1 // LLVM: br label {{.*}} // LLVM: 1: // LLVM: br i1 false, label {{.*}}, label {{.*}} // LLVM: 2: -// LLVM: ret i8 0 +// LLVM: ret i1 false // LLVM: 3: // LLVM: br label {{.*}} // LLVM: 4: -// LLVM: ret i8 0 +// LLVM: ret i1 false diff --git a/clang/test/CIR/Lowering/cast.cir b/clang/test/CIR/Lowering/cast.cir index e100e0c2f07e..7b731794f1fa 100644 --- a/clang/test/CIR/Lowering/cast.cir +++ b/clang/test/CIR/Lowering/cast.cir @@ -51,7 +51,6 @@ module { %33 = cir.cast(int_to_bool, %arg1 : !s32i), !cir.bool // CHECK: %[[#ZERO:]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[#CMP:]] = llvm.icmp "ne" %arg1, %[[#ZERO]] : i32 - // CHECK: %{{.+}} = llvm.zext %[[#CMP]] : i1 to i8 // Pointer casts. cir.store %16, %6 : !s64i, !cir.ptr @@ -91,9 +90,22 @@ module { %2 = cir.load %0 : !cir.ptr, !cir.bool %3 = cir.cast(bool_to_int, %2 : !cir.bool), !u8i // CHECK: %[[LOAD_BOOL:.*]] = llvm.load %{{.*}} : !llvm.ptr -> i8 - // CHECK: %{{.*}} = llvm.bitcast %[[LOAD_BOOL]] : i8 to i8 + // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[LOAD_BOOL]] : i8 to i1 + // CHECK: %[[EXT:.*]] = llvm.zext %[[TRUNC]] : i1 to i8 cir.store %3, %1 : !u8i, !cir.ptr cir.return } + + // Test cases where the memory type is not the same as the source type. + cir.func @testArrayToPtrDecay() { + // CHECK-LABEL: llvm.func @testArrayToPtrDecay() + %null_bool_array = cir.const #cir.ptr : !cir.ptr> + %bool_array_decay = cir.cast(array_to_ptrdecay, %null_bool_array : !cir.ptr>), !cir.ptr + // CHECK: = llvm.getelementptr %{{.*}}[0] : (!llvm.ptr) -> !llvm.ptr, i8 + %res = cir.load %bool_array_decay : !cir.ptr, !cir.bool + // CHECK-NEXT: %[[BOOL_LOAD:.+]] = llvm.load %{{.*}} {{.*}} : !llvm.ptr -> i8 + // CHECK-NEXT: = llvm.trunc %[[BOOL_LOAD]] : i8 to i1 + cir.return + } } diff --git a/clang/test/CIR/Lowering/const-array.cir b/clang/test/CIR/Lowering/const-array.cir index 41cfbad3daba..84a21665bffd 100644 --- a/clang/test/CIR/Lowering/const-array.cir +++ b/clang/test/CIR/Lowering/const-array.cir @@ -1,11 +1,16 @@ // RUN: cir-translate %s -cir-to-llvmir --disable-cc-lowering -o - | FileCheck %s -check-prefix=LLVM !u8i = !cir.int +#false = #cir.bool : !cir.bool +#true = #cir.bool : !cir.bool module { cir.global "private" internal @normal_url_char = #cir.const_array<[#cir.int<0> : !u8i, #cir.int<1> : !u8i], trailing_zeros> : !cir.array // LLVM: @normal_url_char = internal global [4 x i8] c"\00\01\00\00" + cir.global "private" internal @g_const_bool_arr = #cir.const_array<[#true, #false, #true, #false]> : !cir.array + // LLVM: @g_const_bool_arr = internal global [4 x i8] c"\01\00\01\00" + cir.func @c0() -> !cir.ptr> { %0 = cir.get_global @normal_url_char : !cir.ptr> cir.return %0 : !cir.ptr> diff --git a/clang/test/CIR/Lowering/const.cir b/clang/test/CIR/Lowering/const.cir index ae78b8387fc5..7d9b495f784e 100644 --- a/clang/test/CIR/Lowering/const.cir +++ b/clang/test/CIR/Lowering/const.cir @@ -78,8 +78,9 @@ module { // CHECK: llvm.func @testInitArrWithBool() // CHECK: [[ARR:%.*]] = llvm.mlir.undef : !llvm.array<1 x i8> - // CHECK: [[TRUE:%.*]] = llvm.mlir.constant(1 : i8) : i8 - // CHECK: {{.*}} = llvm.insertvalue [[TRUE]], [[ARR]][0] : !llvm.array<1 x i8> - // CHECL: llvm.return + // CHECK: [[TRUE:%.*]] = llvm.mlir.constant(true) : i1 + // CHECK: [[TRUE_EXT:%.*]] = llvm.zext [[TRUE]] : i1 to i8 + // CHECK: {{.*}} = llvm.insertvalue [[TRUE_EXT]], [[ARR]][0] : !llvm.array<1 x i8> + // CHECK: llvm.return } diff --git a/clang/test/CIR/Lowering/loadstorealloca.cir b/clang/test/CIR/Lowering/loadstorealloca.cir index 5764d5afc8f5..85f714dc6b51 100644 --- a/clang/test/CIR/Lowering/loadstorealloca.cir +++ b/clang/test/CIR/Lowering/loadstorealloca.cir @@ -18,7 +18,7 @@ module { %2 = cir.load volatile %0 : !cir.ptr, !u32i cir.return %2 : !u32i } -} + // MLIR: module { // MLIR-NEXT: func @foo() -> i32 @@ -37,3 +37,21 @@ module { // MLIR-NEXT: llvm.store volatile %2, %1 {{.*}}: i32, !llvm.ptr // MLIR-NEXT: %3 = llvm.load volatile %1 {alignment = 4 : i64} : !llvm.ptr -> i32 // MLIR-NEXT: return %3 : i32 + + cir.func @test_bool_memory_lowering() { + // MLIR-LABEL: @test_bool_memory_lowering + %0 = cir.alloca !cir.bool, !cir.ptr, ["x", init] {alignment = 1 : i64} + // MLIR: %[[VAR:.*]] = llvm.alloca %{{.*}} x i8 + %1 = cir.const #cir.bool : !cir.bool + // MLIR: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1 + cir.store %1, %0 : !cir.bool, !cir.ptr + // MLIR: %[[TRUE_EXT:.*]] = llvm.zext %[[TRUE]] : i1 to i8 + // MLIR: llvm.store %[[TRUE_EXT]], %[[VAR]] {alignment = 1 : i64} : i8, !llvm.ptr + %2 = cir.load %0 : !cir.ptr, !cir.bool + // MLIR: %[[LOAD_VAL:.*]] = llvm.load %[[VAR]] {alignment = 1 : i64} : !llvm.ptr -> i8 + // MLIR: %[[LOAD_SCALAR:.*]] = llvm.trunc %[[LOAD_VAL]] : i8 to i1 + %3 = cir.cast(bool_to_int, %2 : !cir.bool), !u32i + // MLIR: %[[CAST_VAL:.*]] = llvm.zext %[[LOAD_SCALAR]] : i1 to i32 + cir.return + } +} diff --git a/clang/test/CIR/Lowering/ptrstride.cir b/clang/test/CIR/Lowering/ptrstride.cir index b5df897d2b0e..648bd0e32da5 100644 --- a/clang/test/CIR/Lowering/ptrstride.cir +++ b/clang/test/CIR/Lowering/ptrstride.cir @@ -2,6 +2,8 @@ // RUN: FileCheck %s --input-file=%t.mlir -check-prefix=MLIR !s32i = !cir.int +!u64i = !cir.int + module { cir.func @f(%arg0: !cir.ptr) { %0 = cir.alloca !cir.ptr, !cir.ptr>, ["a", init] {alignment = 8 : i64} @@ -16,6 +18,11 @@ module { %3 = cir.ptr_stride(%arg0 : !cir.ptr, %2 : !s32i), !cir.ptr cir.return } + + cir.func @bool_stride(%arg0: !cir.ptr, %2 : !u64i) { + %3 = cir.ptr_stride(%arg0 : !cir.ptr, %2 : !u64i), !cir.ptr + cir.return + } } // MLIR-LABEL: @f @@ -32,3 +39,6 @@ module { // MLIR-LABEL: @g // MLIR: %0 = llvm.sext %arg1 : i32 to i64 // MLIR-NEXT: llvm.getelementptr %arg0[%0] : (!llvm.ptr, i64) -> !llvm.ptr, i32 + +// MLIR-LABEL: @bool_stride +// MLIR: llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, i8 diff --git a/clang/test/CIR/Lowering/select.cir b/clang/test/CIR/Lowering/select.cir index 1ac56496e138..71ca79a390e8 100644 --- a/clang/test/CIR/Lowering/select.cir +++ b/clang/test/CIR/Lowering/select.cir @@ -9,9 +9,8 @@ module { cir.return %0 : !s32i } - // LLVM: define i32 @select_int(i8 %[[#COND:]], i32 %[[#TV:]], i32 %[[#FV:]]) - // LLVM-NEXT: %[[#CONDF:]] = trunc i8 %[[#COND]] to i1 - // LLVM-NEXT: %[[#RES:]] = select i1 %[[#CONDF]], i32 %[[#TV]], i32 %[[#FV]] + // LLVM: define i32 @select_int(i1 %[[#COND:]], i32 %[[#TV:]], i32 %[[#FV:]]) + // LLVM-NEXT: %[[#RES:]] = select i1 %[[#COND]], i32 %[[#TV]], i32 %[[#FV]] // LLVM-NEXT: ret i32 %[[#RES]] // LLVM-NEXT: } @@ -20,10 +19,9 @@ module { cir.return %0 : !cir.bool } - // LLVM: define i8 @select_bool(i8 %[[#COND:]], i8 %[[#TV:]], i8 %[[#FV:]]) - // LLVM-NEXT: %[[#CONDF:]] = trunc i8 %[[#COND]] to i1 - // LLVM-NEXT: %[[#RES:]] = select i1 %[[#CONDF]], i8 %[[#TV]], i8 %[[#FV]] - // LLVM-NEXT: ret i8 %[[#RES]] + // LLVM: define i1 @select_bool(i1 %[[#COND:]], i1 %[[#TV:]], i1 %[[#FV:]]) + // LLVM-NEXT: %[[#RES:]] = select i1 %[[#COND]], i1 %[[#TV]], i1 %[[#FV]] + // LLVM-NEXT: ret i1 %[[#RES]] // LLVM-NEXT: } cir.func @logical_and(%arg0 : !cir.bool, %arg1 : !cir.bool) -> !cir.bool { @@ -32,9 +30,9 @@ module { cir.return %1 : !cir.bool } - // LLVM: define i8 @logical_and(i8 %[[#ARG0:]], i8 %[[#ARG1:]]) - // LLVM-NEXT: %[[#RES:]] = and i8 %[[#ARG0]], %[[#ARG1]] - // LLVM-NEXT: ret i8 %[[#RES]] + // LLVM: define i1 @logical_and(i1 %[[#ARG0:]], i1 %[[#ARG1:]]) + // LLVM-NEXT: %[[#RES:]] = and i1 %[[#ARG0]], %[[#ARG1]] + // LLVM-NEXT: ret i1 %[[#RES]] // LLVM-NEXT: } cir.func @logical_or(%arg0 : !cir.bool, %arg1 : !cir.bool) -> !cir.bool { @@ -43,8 +41,8 @@ module { cir.return %1 : !cir.bool } - // LLVM: define i8 @logical_or(i8 %[[#ARG0:]], i8 %[[#ARG1:]]) - // LLVM-NEXT: %[[#RES:]] = or i8 %[[#ARG0]], %[[#ARG1]] - // LLVM-NEXT: ret i8 %[[#RES]] + // LLVM: define i1 @logical_or(i1 %[[#ARG0:]], i1 %[[#ARG1:]]) + // LLVM-NEXT: %[[#RES:]] = or i1 %[[#ARG0]], %[[#ARG1]] + // LLVM-NEXT: ret i1 %[[#RES]] // LLVM-NEXT: } } diff --git a/clang/test/CIR/Lowering/struct.cir b/clang/test/CIR/Lowering/struct.cir index c89a58a9772e..e612dcd66efd 100644 --- a/clang/test/CIR/Lowering/struct.cir +++ b/clang/test/CIR/Lowering/struct.cir @@ -10,6 +10,8 @@ !ty_S2_ = !cir.struct !ty_S3_ = !cir.struct +!struct_with_bool = !cir.struct + module { cir.func @test() { %1 = cir.alloca !ty_S, !cir.ptr, ["x"] {alignment = 4 : i64} @@ -93,4 +95,21 @@ module { // CHECK: "llvm.intr.memcpy"(%[[#SB]], %[[#SA]], %[[#SIZE]]) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () cir.return } + + // Verify that boolean fields are lowered to i8 and that the correct type is inserted during initialization. + cir.global external @struct_with_bool = #cir.const_struct<{#cir.int<1> : !u32i, #cir.bool : !cir.bool}> : !struct_with_bool + // CHECK: llvm.mlir.global external @struct_with_bool() {addr_space = 0 : i32} : !llvm.struct<"struct.struct_with_bool", (i32, i8)> { + // CHECK: %[[FALSE:.+]] = llvm.mlir.constant(false) : i1 + // CHECK-NEXT: %[[FALSE_MEM:.+]] = llvm.zext %[[FALSE]] : i1 to i8 + // CHECK-NEXT: = llvm.insertvalue %[[FALSE_MEM]], %{{.+}}[1] : !llvm.struct<"struct.struct_with_bool", (i32, i8)> + + cir.func @test_struct_with_bool() { + // CHECK-LABEL: llvm.func @test_struct_with_bool() + %0 = cir.alloca !struct_with_bool, !cir.ptr, ["a"] {alignment = 4 : i64} + %1 = cir.get_member %0[1] {name = "b"} : !cir.ptr -> !cir.ptr + // CHECK: %[[BOOL_MEMBER_PTR:.+]] = llvm.getelementptr %{{.*}}[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"struct.struct_with_bool", (i32, i8)> + %2 = cir.load %1 : !cir.ptr, !cir.bool + // CHECK: = llvm.load %[[BOOL_MEMBER_PTR]] {{.*}} : !llvm.ptr -> i8 + cir.return + } } diff --git a/clang/test/CIR/Lowering/unary-not.cir b/clang/test/CIR/Lowering/unary-not.cir index 86a7405bd0ee..35cd54f3df78 100644 --- a/clang/test/CIR/Lowering/unary-not.cir +++ b/clang/test/CIR/Lowering/unary-not.cir @@ -31,18 +31,16 @@ module { %3 = cir.cast(float_to_bool, %2 : !cir.float), !cir.bool // MLIR: %[[#F_ZERO:]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 // MLIR: %[[#F_BOOL:]] = llvm.fcmp "une" %{{.+}}, %[[#F_ZERO]] : f32 - // MLIR: %[[#F_ZEXT:]] = llvm.zext %[[#F_BOOL]] : i1 to i8 %4 = cir.unary(not, %3) : !cir.bool, !cir.bool - // MLIR: %[[#F_ONE:]] = llvm.mlir.constant(1 : i8) : i8 - // MLIR: = llvm.xor %[[#F_ZEXT]], %[[#F_ONE]] : i8 + // MLIR: %[[#F_ONE:]] = llvm.mlir.constant(true) : i1 + // MLIR: = llvm.xor %[[#F_BOOL]], %[[#F_ONE]] : i1 %5 = cir.load %1 : !cir.ptr, !cir.double %6 = cir.cast(float_to_bool, %5 : !cir.double), !cir.bool // MLIR: %[[#D_ZERO:]] = llvm.mlir.constant(0.000000e+00 : f64) : f64 // MLIR: %[[#D_BOOL:]] = llvm.fcmp "une" %{{.+}}, %[[#D_ZERO]] : f64 - // MLIR: %[[#D_ZEXT:]] = llvm.zext %[[#D_BOOL]] : i1 to i8 %7 = cir.unary(not, %6) : !cir.bool, !cir.bool - // MLIR: %[[#D_ONE:]] = llvm.mlir.constant(1 : i8) : i8 - // MLIR: = llvm.xor %[[#D_ZEXT]], %[[#D_ONE]] : i8 + // MLIR: %[[#D_ONE:]] = llvm.mlir.constant(true) : i1 + // MLIR: = llvm.xor %[[#D_BOOL]], %[[#D_ONE]] : i1 cir.return } @@ -60,10 +58,9 @@ module { // MLIR: %[[#INT:]] = llvm.load %{{.+}} : !llvm.ptr // MLIR: %[[#IZERO:]] = llvm.mlir.constant(0 : i32) : i32 // MLIR: %[[#ICMP:]] = llvm.icmp "ne" %[[#INT]], %[[#IZERO]] : i32 - // MLIR: %[[#IEXT:]] = llvm.zext %[[#ICMP]] : i1 to i8 - // MLIR: %[[#IONE:]] = llvm.mlir.constant(1 : i8) : i8 - // MLIR: %[[#IXOR:]] = llvm.xor %[[#IEXT]], %[[#IONE]] : i8 - // MLIR: = llvm.zext %[[#IXOR]] : i8 to i32 + // MLIR: %[[#IONE:]] = llvm.mlir.constant(true) : i1 + // MLIR: %[[#IXOR:]] = llvm.xor %[[#ICMP]], %[[#IONE]] : i1 + // MLIR: = llvm.zext %[[#IXOR]] : i1 to i32 %17 = cir.load %3 : !cir.ptr, !cir.float %18 = cir.cast(float_to_bool, %17 : !cir.float), !cir.bool @@ -72,10 +69,9 @@ module { // MLIR: %[[#FLOAT:]] = llvm.load %{{.+}} : !llvm.ptr // MLIR: %[[#FZERO:]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 // MLIR: %[[#FCMP:]] = llvm.fcmp "une" %[[#FLOAT]], %[[#FZERO]] : f32 - // MLIR: %[[#FEXT:]] = llvm.zext %[[#FCMP]] : i1 to i8 - // MLIR: %[[#FONE:]] = llvm.mlir.constant(1 : i8) : i8 - // MLIR: %[[#FXOR:]] = llvm.xor %[[#FEXT]], %[[#FONE]] : i8 - // MLIR: = llvm.zext %[[#FXOR]] : i8 to i32 + // MLIR: %[[#FONE:]] = llvm.mlir.constant(true) : i1 + // MLIR: %[[#FXOR:]] = llvm.xor %[[#FCMP]], %[[#FONE]] : i1 + // MLIR: = llvm.zext %[[#FXOR]] : i1 to i32 cir.return } diff --git a/clang/test/CIR/Lowering/unions.cir b/clang/test/CIR/Lowering/unions.cir index fe56e2af7527..445ef463ef2d 100644 --- a/clang/test/CIR/Lowering/unions.cir +++ b/clang/test/CIR/Lowering/unions.cir @@ -25,9 +25,10 @@ module { %5 = cir.const #true %6 = cir.get_member %arg0[0] {name = "b"} : !cir.ptr -> !cir.ptr cir.store %5, %6 : !cir.bool, !cir.ptr - // CHECK: %[[#VAL:]] = llvm.mlir.constant(1 : i8) : i8 + // CHECK: %[[#TRUE:]] = llvm.mlir.constant(true) : i1 // The bitcast it just to bypass the type checker. It will be replaced by an opaque pointer. // CHECK: %[[#ADDR:]] = llvm.bitcast %{{.+}} : !llvm.ptr + // CHECK: %[[#VAL:]] = llvm.zext %[[#TRUE]] : i1 to i8 // CHECK: llvm.store %[[#VAL]], %[[#ADDR]] {{.*}}: i8, !llvm.ptr // Should load direclty from the union's base address.